| | |
| |
|
| | import functools |
| |
|
| | import seqio |
| | import tensorflow as tf |
| | import t5.data |
| | from datasets import load_dataset, load_from_disk |
| | from t5.data import postprocessors |
| | from t5.data import preprocessors |
| | from t5.evaluation import metrics |
| | from seqio import FunctionDataSource, utils |
| |
|
# Convenience alias for seqio's global task registry.
TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary loaded from a local model file.
# NOTE(review): extra_ids=0 — T5-style span corruption normally relies on
# sentinel/extra ids; confirm 'spiece.model' already contains the sentinel
# tokens, or that the pipeline does not need them.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

# Feature spec shared by tasks in this file. "inputs" is optional
# (required=False) so targets-only (LM-style) tasks can reuse this mapping;
# both features append an EOS token after tokenization.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}
| |
|
| |
|
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield raw text examples from ``dataset[split]``, looping forever.

    Args:
        split: Split name, e.g. "train" or "validation".
        shuffle: If True, shuffle the dataset before iterating.
        seed: Optional shuffle seed. ``seed=0`` is honored (checked with
            ``is not None`` rather than truthiness).
        column: Name of the column holding each example's text.
        dataset: A mapping of split name -> iterable of example dicts
            (e.g. a HuggingFace DatasetDict).

    Yields:
        str: The value stored under ``column`` for each example, cycling
        endlessly over the split.
    """
    if shuffle:
        # Bug fix: the original `if seed:` treated seed=0 as "no seed" and
        # silently fell through to an unseeded shuffle.
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    # Infinite stream: restart the split whenever it is exhausted.
    while True:
        for item in dataset[str(split)]:
            yield item[column]
| |
|
| |
|
def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Build a tf.data.Dataset of scalar strings backed by `gen_dataset`.

    `shuffle_files` is forwarded to `gen_dataset` as its `shuffle` flag.
    """
    generator = functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset)
    # NOTE(review): this relies on the module-level `dataset_name` global,
    # which is defined further down the file — works only because the
    # function is called after module import completes. Confirm the path
    # string is accepted as a TensorSpec name.
    spec = tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(generator, output_signature=spec)
| |
|
| |
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Build a feature dict from `key_map`, storing `x` under `target_key`."""
    features = dict(key_map)
    features[target_key] = x
    return features
| |
|
| |
|
| | |
# Training corpus: a HuggingFace dataset previously saved to local disk.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}

# Prefer a dataset persisted with `save_to_disk`; otherwise fall back to
# `load_dataset(**dataset_params)` (e.g. a hub dataset name).
if "from_disk_path" in dataset_params:
    dataset = load_from_disk(dataset_params["from_disk_path"])
else:
    dataset = load_dataset(**dataset_params)
|
# Example counts per split, handed to seqio via num_input_examples.
dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}

# Register the Finnish span-corruption pretraining task.
TaskRegistry.add(
    "pretrain_finnish",
    source=seqio.FunctionDataSource(
        # `split`/`shuffle_files` are filled in by seqio; the dataset is bound here.
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        # The source is an endless generator streaming from disk, so seqio
        # caching is disallowed.
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Wrap each raw text string into a feature dict with the text under
        # "targets" (and "inputs" left as None).
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,

        # NOTE(review): span_corruption is expected to derive "inputs" from
        # the tokenized "targets"; confirm this is compatible with
        # extra_ids=0 in the vocabulary above.
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    # NOTE(review): only "targets" is declared here even though
    # span_corruption emits an "inputs" feature — confirm this passes
    # seqio.Task output-feature validation for this pipeline.
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy]
)