| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| All utilities related to data handling. |
| """ |
|
|
| from collections.abc import Callable |
| from functools import partial |
|
|
| import datasets |
| import numpy as np |
| from datasets import Dataset, load_dataset |
|
|
|
|
| |
| |
# Maximum combined character length of a query + response pair that is kept.
# NOTE(review): the docstring below cites a 99th percentile of 529 characters
# for queries alone; 1300 presumably covers query + response with headroom —
# confirm the intended relationship between these two numbers.
CHAR_LIMIT = 1300

# Number of GSM8K training rows held out as the validation split.
VALID_SIZE = 50


# Annotations are quoted (PEP 484 forward references) so defining this function
# does not require evaluating the third-party `datasets` types at import time.
def get_filtered_dataset(*, ds: "datasets.Dataset", print_fn: Callable[..., None]) -> "Dataset":
    """Return the filtered dataset, with long query/response pairs removed.

    We determined that 99% of queries have 529 or fewer characters. Characters
    roughly correspond to tokens, so this is a good proxy. We cannot use tokens
    directly, as that depends on the tokenizer, which can be different for each
    model, but we want the same filter for each model.

    Args:
        ds: Dataset with "query" and "response" string columns.
        print_fn: Logging callable (e.g. ``print``) used to report retention.

    Returns:
        The subset of ``ds`` whose combined ``"{query} {response}"`` length is
        at most ``CHAR_LIMIT`` characters.
    """
    # Guard: an empty dataset has nothing to filter, and the percentage below
    # would otherwise raise ZeroDivisionError.
    if len(ds) == 0:
        return ds
    char_lengths = [len(f"{q} {r}") for q, r in zip(ds["query"], ds["response"])]
    idx_filtered = [i for i, length in enumerate(char_lengths) if length <= CHAR_LIMIT]
    print_fn(f"Filtered dataset: {100 * len(idx_filtered) / len(ds):.1f}% of the original dataset")
    return ds.select(idx_filtered)
|
|
|
|
def get_train_valid_test_datasets(
    *, tokenizer, query_template: str, print_fn: Callable[..., None]
) -> tuple[Dataset, Dataset, Dataset]:
    """
    Return the tokenized train, valid, and test splits.

    Train comes from MetaMathQA (length-filtered); valid is a fixed random
    subset of the GSM8K train split; test is the GSM8K test split.

    We cannot use ds.train_test_split(..., stratify_by_column="type") as it gives:

    > ValueError: Stratifying by column is only supported for ClassLabel column, and column type is Value.

    even after calling ds_filtered.class_encode_column("type"). Thus, using sklearn's StratifiedKFold instead.
    """
    # Training pool: MetaMathQA with over-long query/response pairs dropped.
    ds_train = get_filtered_dataset(ds=load_dataset("meta-math/MetaMathQA")["train"], print_fn=print_fn)

    # GSM8K supplies validation and test; align its column names with MetaMath.
    gsm8k = load_dataset("openai/gsm8k", "main").rename_columns(
        {"question": "query", "answer": "response"}
    )

    # Fixed seed so the held-out validation subset is reproducible across runs.
    np.random.seed(0)
    shuffled = np.arange(len(gsm8k["train"]))
    np.random.shuffle(shuffled)
    ds_valid = gsm8k["train"].select(shuffled[:VALID_SIZE])
    ds_test = gsm8k["test"]

    for label, split in (("Train", ds_train), ("Valid", ds_valid), ("Test", ds_test)):
        print_fn(f"{label} size: {len(split)}")

    # Train keeps the answer in the prompt; valid/test are prompt-only.
    with_answer = partial(tokenize_with_answer, tokenizer=tokenizer, template=query_template)
    without_answer = partial(tokenize_wo_answer, tokenizer=tokenizer, template=query_template)
    ds_train = ds_train.map(with_answer, batched=True).remove_columns(["type", "query", "original_question"])
    ds_valid = ds_valid.map(without_answer, batched=True).remove_columns(["query"])
    ds_test = ds_test.map(without_answer, batched=True).remove_columns(["query"])

    return ds_train, ds_valid, ds_test
|
|
|
|
def tokenize_with_answer(samples, tokenizer, template):
    """Tokenize templated query + answer pairs, truncated to the tokenizer's max length."""
    max_len = tokenizer.model_max_length
    texts = [
        template.format(query=q) + a
        for q, a in zip(samples["query"], samples["response"])
    ]
    tokenized = tokenizer(texts)
    # Truncate both sequences in lockstep so ids and masks stay aligned.
    for key in ("input_ids", "attention_mask"):
        tokenized[key] = [seq[:max_len] for seq in tokenized[key]]
    return tokenized
|
|
|
|
def tokenize_wo_answer(samples, tokenizer, template):
    """Tokenize templated queries (no answer appended), truncated to the tokenizer's max length."""
    max_len = tokenizer.model_max_length
    texts = [template.format(query=q) for q in samples["query"]]
    tokenized = tokenizer(texts)
    # Truncate both sequences in lockstep so ids and masks stay aligned.
    for key in ("input_ids", "attention_mask"):
        tokenized[key] = [seq[:max_len] for seq in tokenized[key]]
    return tokenized
|
|
|
|
def get_wiki_small(num_samples: int = 100) -> list[str]:
    """Return the text of the first ``num_samples`` FineWiki articles (streamed, no full download)."""
    stream = load_dataset("HuggingFaceFW/finewiki", split="train", streaming=True)
    return [sample["text"] for sample in stream.take(num_samples)]
|
|