| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import List, Optional, Union |
|
|
| from tokenizers import Tokenizer as BaseTokenizer |
|
|
| from .aliases import PathOrStr |
| from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection |
| from .exceptions import OLMoConfigurationError |
|
|
# Public API of this module: `from <module> import *` exports only the `Tokenizer` wrapper.
__all__ = ["Tokenizer"]
|
|
|
|
class Tokenizer:
    """
    A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.

    :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use. Truncation on this object is
        disabled (in place) on construction, because truncation is applied by this wrapper instead.
    :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
    :param pad_token_id: The token ID used for padding. Falls back to ``eos_token_id`` when ``None``.
    :param truncate_to: Truncate when tokenizing to this number of token IDs.
    :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
        on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
        this setting has no effect.
    """

    def __init__(
        self,
        base_tokenizer: BaseTokenizer,
        eos_token_id: int,
        pad_token_id: Optional[int] = None,
        truncate_to: Optional[int] = None,
        truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
    ):
        self.base_tokenizer = base_tokenizer
        # Truncation is handled by `encode_batch()` (which reserves room for special tokens),
        # so any truncation configured on the underlying tokenizer is switched off.
        self.base_tokenizer.no_truncation()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
        self.truncate_to = truncate_to
        # Accepts either the enum member or its string value.
        self.truncate_direction = TruncationDirection(truncate_direction)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the underlying tokenizer."""
        return self.base_tokenizer.get_vocab_size()

    @property
    def eos_token(self) -> str:
        """The decoded text of the EOS token."""
        return self.decode([self.eos_token_id], skip_special_tokens=False)

    @property
    def pad_token(self) -> str:
        """The decoded text of the padding token."""
        return self.decode([self.pad_token_id], skip_special_tokens=False)

    @classmethod
    def _from_identifier(cls, identifier: str, model_config: ModelConfig) -> Tokenizer:
        """
        Shared loader for :meth:`from_train_config` and :meth:`from_checkpoint`.

        :param identifier: A path to a local tokenizer file, or a HuggingFace Hub identifier.
        :param model_config: Supplies ``eos_token_id``, ``pad_token_id`` and ``vocab_size``.
        :raises OLMoConfigurationError: If ``model_config.vocab_size`` differs from the
            loaded tokenizer's vocab size.
        """
        # An existing local file is loaded directly; anything else is treated as a Hub identifier.
        loader = cls.from_file if Path(identifier).is_file() else cls.from_pretrained
        tokenizer = loader(
            identifier,
            eos_token_id=model_config.eos_token_id,
            pad_token_id=model_config.pad_token_id,
        )
        if model_config.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer

    @classmethod
    def from_train_config(cls, config: TrainConfig) -> Tokenizer:
        """
        Build a tokenizer from ``config.tokenizer.identifier``, taking the special token IDs
        and the expected vocab size from ``config.model``.

        :raises OLMoConfigurationError: If the configured vocab size does not match the tokenizer.
        """
        return cls._from_identifier(config.tokenizer.identifier, config.model)

    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        :param kwargs: Other key word arguments passed to :class:`Tokenizer`. If
            ``eos_token_id`` is not given, the last ID in the vocabulary is used.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a file.

        You can create those files with ``BaseTokenizer.save()``.

        :param filename: The name of a file containing a tokenizer specification.
        :param kwargs: Other key word arguments passed to :class:`Tokenizer`. If
            ``eos_token_id`` is not given, the last ID in the vocabulary is used.
        """
        base_tokenizer = BaseTokenizer.from_file(filename)
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
        """
        Load a tokenizer from a checkpoint.

        Reads the ``tokenizer`` and ``model`` sections of ``<checkpoint_dir>/config.yaml``;
        the file is resolved through ``cached_path`` (presumably so remote checkpoint
        locations work as well as local ones).

        :raises OLMoConfigurationError: If the model's vocab size does not match the tokenizer.
        """
        from cached_path import cached_path

        config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
        tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
        model_config = ModelConfig.load(config_path, key="model")
        return cls._from_identifier(tokenizer_config.identifier, model_config)

    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.

        Appends ``eos_token_id`` unless the list already ends with it; the same (mutated)
        list is returned.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        """Number of special tokens reserved per encoded sequence (2 for a pair, else 1)."""
        return 2 if is_pair else 1

    def _truncate(
        self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
    ) -> List[int]:
        """
        Return at most ``truncate_to`` token IDs, dropping tokens from the side given by
        ``direction``. ``None`` disables truncation; a negative budget keeps nothing.
        """
        if truncate_to is None or len(input_ids) <= truncate_to:
            return input_ids
        # A negative budget can arise when `encode_batch` reserves room for special tokens.
        keep = max(truncate_to, 0)
        if direction == TruncationDirection.left:
            # Use an explicit start index: `input_ids[-keep:]` would keep everything when keep == 0.
            return input_ids[len(input_ids) - keep :]
        return input_ids[:keep]

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.

        Each sequence is truncated according to ``truncate_to`` / ``truncate_direction`` and,
        if ``add_special_tokens`` is set, then has the EOS token appended.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            # Reserve room for the special token appended below.
            truncate_to -= self.num_special_tokens_to_add(False)

        all_input_ids = []
        for encoding in self.base_tokenizer.encode_batch(inputs):
            input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)
        return all_input_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode a list of token IDs to a string.
        """
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
|