#!/usr/bin/env python
"""Build shuffled, token-counted train/test parquet files from a config module.

The config file is a Python module that must define ``sources``,
``tokenizer_name`` and ``prefix`` (see ``load_config``).
"""
import importlib
import importlib.util
import os

import datasets
import tqdm
import transformers
import typer

# Tokenizer used by ``tokenize``; assigned in ``main``. A module-level default
# means ``tokenize`` yields zero counts instead of raising NameError when
# ``main`` has not configured a tokenizer.
tokenizer = None


def load_config(config_file: str):
    """Import ``config_file`` as a Python module and return its settings.

    Returns:
        A ``(sources, tokenizer_name, prefix)`` tuple read from the module
        attributes of the same names.

    Raises:
        ValueError: if ``config_file`` cannot be loaded as a Python module.
    """
    spec = importlib.util.spec_from_file_location("config", config_file)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None e.g. for an unrecognised file
        # suffix; report that clearly instead of failing with AttributeError.
        raise ValueError(f"Cannot load config module from {config_file!r}")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.sources, config_module.tokenizer_name, config_module.prefix


def tokenize(batch: dict):
    """Count tokens per row of a batch, for ``Dataset.map(batched=True)``.

    Returns ``{"num_tokens": [...]}`` with one entry per element of
    ``batch["text"]``; every count is 0 when no tokenizer is configured.
    """
    texts = batch["text"]
    if tokenizer is not None:
        return {"num_tokens": tokenizer(texts, padding="do_not_pad", return_length=True)["length"]}
    # A batched map needs one value per row, not a single scalar.
    return {"num_tokens": [0] * len(texts)}


def shard_indices(shard_index):
    """Normalise a shard-index setting to a list (a scalar becomes ``[scalar]``)."""
    if not isinstance(shard_index, list):
        shard_index = [shard_index]
    return shard_index


def preprocess_shard(ds: datasets.Dataset, num_shards: int, index: int, num_proc: int):
    """Take contiguous shard ``index`` of ``num_shards`` and add ``num_tokens``."""
    shard = ds.shard(num_shards=num_shards, index=index, contiguous=True)
    # Materialise the selected rows so the map below works on contiguous data.
    shard = shard.flatten_indices()
    shard = shard.map(tokenize, batched=True, batch_size=1000, num_proc=num_proc)
    return shard


def preprocess_subset(
    weights: dict,
    subsets: list,
    source: str,
    src_info: dict,
    dc: datasets.DownloadConfig,
    num_proc: int,
    desc: str = "Loading train subsets",
):
    """Load, shard and token-count each subset in ``weights``; append to ``subsets``.

    Args:
        weights: maps subset key -> ``frac``; the subset is split into
            ``src_info["shards"] / frac`` shards (truncated to int).
        subsets: output list; one processed dataset is appended per key.
        source: value written to the ``source`` column.
        src_info: source config; reads ``uri`` (format template with ``{key}``),
            ``format``, ``shards`` and ``shard_index``.
        dc: download configuration passed to ``datasets.load_dataset``.
        num_proc: worker count for the tokenizing map.
        desc: progress-bar label.
    """
    uri_template = src_info["uri"]
    for key, frac in tqdm.tqdm(weights.items(), desc=desc):
        data_files = uri_template.format(key=key)
        print(f" Loading subset: {key} with fraction 1/{frac} from {data_files}")
        ds = datasets.load_dataset(
            src_info["format"],
            data_files=data_files,
            split="train",
            download_config=dc,
        )
        ds = ds.select_columns(["text"])
        ds = ds.add_column("source", [source] * len(ds))
        ds = ds.add_column("subset", [key] * len(ds))
        # Fixed seed so the shard contents are reproducible across runs.
        ds = ds.shuffle(seed=42)
        num_shards = int(src_info["shards"] / frac)
        dss = [preprocess_shard(ds, num_shards, i, num_proc) for i in shard_indices(src_info["shard_index"])]
        ds = datasets.concatenate_datasets(dss)
        ds = ds.cast_column("text", datasets.Value("large_string"))
        print(f" Finished preprocessing subset: {key} with {sum(ds['num_tokens'])} tokens")
        subsets.append(ds)


def _finalize(subsets: list, split: str):
    """Concatenate, shuffle (seed 42) and materialise the subsets of one split."""
    print(f"Concatenating {split} subsets")
    ds = datasets.concatenate_datasets(subsets)
    print(f"Shuffling final {split} dataset")
    ds = ds.shuffle(seed=42)
    print(f"Flattening final {split} dataset")
    return ds.flatten_indices()


def _write(ds: datasets.Dataset, file_name: str, split: str) -> None:
    """Write ``ds`` to ``{file_name}{split}/{file_name}{split}.parquet``."""
    path = f"{file_name}{split}/{file_name}{split}.parquet"
    print(f"Writing final {split} dataset with {sum(ds['num_tokens'])} tokens to {path}")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    ds.to_parquet(path)


def main(
    config_file: str,
    num_proc: int = 96,
    max_retries: int = 10,
):
    """Build the train/test parquet files described by ``config_file``."""
    global tokenizer
    sources, tokenizer_name, prefix = load_config(config_file)
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) if tokenizer_name else None
    dc = datasets.DownloadConfig(num_proc=num_proc, max_retries=max_retries)
    train_subsets = []
    test_subsets = []
    # Output name records every source and the shard indices taken from it.
    file_name = f"{prefix}-"
    for source, src_info in sources.items():
        print(f"Processing source: {source}")
        indices = "_".join(str(s) for s in shard_indices(src_info["shard_index"]))
        file_name += f"{source}-{indices}-of-{src_info['shards']}-"
        preprocess_subset(src_info["train"], train_subsets, source, src_info, dc, num_proc)
        preprocess_subset(
            src_info["test"], test_subsets, source, src_info, dc, num_proc, desc="Loading test subsets"
        )
    final_train = _finalize(train_subsets, "train")
    final_test = _finalize(test_subsets, "test")
    _write(final_test, file_name, "test")
    _write(final_train, file_name, "train")


if __name__ == "__main__":
    typer.run(main)