Spaces:
Runtime error
Runtime error
| import tqdm | |
| import torch | |
| from einops import rearrange | |
def scalar_to_batch_tensor(x, batch_size):
    """
    Build a tensor from ``x`` and repeat it ``batch_size`` times.

    For a Python scalar ``x`` the result is a 1-D tensor with
    ``batch_size`` entries, each equal to ``x``.
    """
    value = torch.tensor(x)
    return value.repeat(batch_size)
def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """
    Map ``fn`` over ``iterables`` while showing a tqdm progress bar.

    Args:
        fn: callable invoked with one item drawn from each iterable.
        *iterables: one or more iterables, consumed in lockstep.
        parallel: execution strategy; one of "thread_map", "process_map"
            or "single".
        **kwargs: forwarded unchanged to tqdm's ``thread_map`` /
            ``process_map``. They are not used in "single" mode.

    Returns:
        Whatever ``thread_map`` / ``process_map`` returns, or, in "single"
        mode, a list with one result per group of items.

    Raises:
        ValueError: if ``parallel`` is not one of the supported strategies.
    """
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(fn, *iterables, **kwargs)

    if parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(fn, *iterables, **kwargs)

    if parallel == "single":
        # zip() hides the length from tqdm, so work out the total up front
        # when every iterable is sized; otherwise let tqdm run without one.
        try:
            total = min(len(it) for it in iterables)
        except (TypeError, ValueError):
            total = None
        # Zip the iterables so fn receives one argument per iterable, the
        # same calling convention the thread/process branches use.
        return [
            fn(*args)
            for args in tqdm.tqdm(zip(*iterables), total=total)
        ]

    raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
def codebook_flatten(tokens: torch.Tensor):
    """
    Merge the codebook and time axes of ``tokens`` into one axis.

    Takes a (batch, codebook, time) layout and returns
    (batch, time * codebook); within the merged axis the codebook index
    varies fastest, so all codebooks of one timestep sit next to each other.
    """
    pattern = "b c t -> b (t c)"
    return rearrange(tokens, pattern)
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    Split a flattened token sequence back into separate codebooks.

    Takes a (batch, time * codebook) layout, where the codebook index
    varies fastest, and returns (batch, codebook, time).

    Args:
        flat_tokens: the flattened token tensor.
        n_c: number of codebooks. Required: it cannot be recovered from
            the flattened shape alone. The ``None`` default is kept only
            so the signature stays unchanged.

    Raises:
        ValueError: if ``n_c`` is not provided.
    """
    if n_c is None:
        # Passing None through as an axis length only produces an opaque
        # error inside einops; fail early with an actionable message.
        raise ValueError(
            "codebook_unflatten requires n_c (the number of codebooks)"
        )
    return rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)