| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| | import sys |
| | import time |
| | from dataclasses import dataclass, field |
| | from fractions import Fraction |
| |
|
| | import torch as th |
| | from torch import distributed, nn |
| | from torch.nn.parallel.distributed import DistributedDataParallel |
| |
|
| | from .augment import FlipChannels, FlipSign, Remix, Shift |
| | from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks |
| | from .model import Demucs |
| | from .parser import get_name, get_parser |
| | from .raw import Rawset |
| | from .tasnet import ConvTasNet |
| | from .test import evaluate |
| | from .train import train_model, validate_model |
| | from .utils import human_seconds, load_model, save_model, sizeof_fmt |
| |
|
| |
|
@dataclass
class SavedState:
    # Checkpoint payload serialized to disk with th.save and reloaded with th.load.
    metrics: list = field(default_factory=list)  # one dict per completed epoch: train/valid/best losses + duration
    last_state: dict = None  # model.state_dict() from the most recent epoch; None until the first save
    best_state: dict = None  # CPU clone of the state_dict with the best validation loss so far
    optimizer: dict = None  # optimizer.state_dict(); restored together with last_state on resume
| |
|
| |
|
def main():
    """Entry point: parse CLI args, build the model and datasets, train with
    periodic checkpointing, then evaluate on MusDB and export the best model.

    Supports multi-process data-parallel training (``--world_size``/``--rank``
    over NCCL); only rank 0 writes logs, metrics, the checkpoint and the
    exported model.
    """
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    # Create all output directories up front so later writes cannot fail.
    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        # Auto-select: prefer CUDA when available, else fall back to CPU.
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Cap OpenMP threading — presumably to avoid CPU oversubscription when
    # multiple worker processes run side by side (NOTE(review): confirm).
    os.environ["OMP_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print("Error: distributed training is only available with cuda device", file=sys.stderr)
            sys.exit(1)
        # Pin each rank to one local GPU, then join the NCCL process group.
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists():
        checkpoint.unlink()

    if args.test:
        # Evaluation-only run: a single pass, no data repetition, pretrained model.
        args.epochs = 1
        args.repeat = 0
        model = load_model(args.models / args.test)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            sources=4,
            stride=args.conv_stride,
            upsample=args.upsample,
            samplerate=args.samplerate
        )
    model.to(device)
    if args.show:
        print(model)
        # 4 bytes per parameter, assuming float32 weights.
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        # No (readable) checkpoint yet: start from scratch.
        saved = SavedState()
    else:
        # Resume: restore both model weights and optimizer state.
        model.load_state_dict(saved.last_state)
        optimizer.load_state_dict(saved.optimizer)

    if args.save_model:
        # Export-only mode: write the best weights seen so far and exit.
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, args.models / f"{name}.th")
        return

    if args.rank == 0:
        # Remove any stale completion marker from a previous run.
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    if args.augment:
        augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
                                Remix(group_size=args.remix_group_size)).to(device)
    else:
        augment = Shift(args.data_stride)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Round the sample count up to a length the model can process exactly.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples + args.data_stride,
                           channels=args.audio_channels,
                           streams=[0, 1, 2, 3, 4],
                           stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            # Wait until rank 0 has finished writing the metadata file.
            distributed.barrier()
        with open(args.metadata) as fp:
            metadata = json.load(fp)
        duration = Fraction(samples + args.data_stride, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                             metadata,
                             duration=duration,
                             stride=stride,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                             metadata,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)

    # Replay the metrics of already-completed epochs (resume case) so the log
    # is contiguous, and recover the best validation loss reached so far.
    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model,
                                         device_ids=[th.cuda.current_device()],
                                         output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(epoch,
                                 train_set,
                                 dmodel,
                                 criterion,
                                 optimizer,
                                 augment,
                                 batch_size=args.batch_size,
                                 device=device,
                                 repeat=args.repeat,
                                 seed=args.seed,
                                 workers=args.workers,
                                 world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch,
                                    valid_set,
                                    model,
                                    criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            # Keep a CPU copy of the best weights so it survives device resets
            # and can be exported without touching the GPU.
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }
        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration
        })
        if args.rank == 0:
            with open(metrics_path, "w") as fp:
                json.dump(saved.metrics, fp)

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            # Write-then-rename so a crash mid-save cannot corrupt the checkpoint.
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: "
              f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
              f"duration={human_seconds(duration)}")

    # Final evaluation with the best weights; optionally on CPU.
    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model,
             args.musdb,
             eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             workers=args.eval_workers)
    model.to("cpu")
    save_model(model, args.models / f"{name}.th")
    if args.rank == 0:
        print("done")
        done.write_text("done")
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|