| """Module containing PromptTokenizingStrategy and Prompter classes""" |
|
|
| import abc |
| import copy |
| import functools |
| import logging |
| from typing import Dict, List, Tuple, Union |
|
|
| from transformers import PreTrainedTokenizer |
|
|
| from axolotl.prompters import IGNORE_TOKEN_ID |
|
|
| LOG = logging.getLogger("axolotl") |
|
|
| IGNORE_INDEX = -100 |
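# -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so label
# positions set to IGNORE_INDEX are skipped when computing the loss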
| LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" |
| LLAMA_DEFAULT_EOS_TOKEN = "</s>" |
| LLAMA_DEFAULT_BOS_TOKEN = "<s>" |
| LLAMA_DEFAULT_UNK_TOKEN = "<unk>" |
|
|
|
|
| class InvalidDataException(Exception): |
| """ |
| Exception raised when the data is invalid |
| """ |
|
|
|
|
| class PromptTokenizingStrategy(abc.ABC): |
| """ |
| Abstract class for tokenizing strategies |
| """ |
|
|
| def __init__( |
| self, |
| prompter, |
| tokenizer, |
| train_on_inputs: bool = False, |
| sequence_len: int = 2048, |
| ): |
| self.prompter = prompter |
| self.tokenizer: PreTrainedTokenizer = tokenizer |
| self.train_on_inputs = train_on_inputs |
| self.sequence_len = sequence_len |
|
|
| @abc.abstractmethod |
| def tokenize_prompt(self, prompt): |
| pass |
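
    # Each strategy maps one raw dataset record to a dict of equal-length
    # lists (sketch of the expected contract):
    #   {"input_ids": [...], "attention_mask": [...], "labels": [...]}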
|
|
    @functools.lru_cache(maxsize=128)
    def _get_user_token(self):
        # some chat models (e.g. StableLM-Tuned) define a dedicated "<|USER|>"
        # token; convert_tokens_to_ids returns the unk token id for tokens
        # missing from the vocab rather than raising, so only a real, non-unk
        # integer id counts as a usable user token
        token_id = self.tokenizer.convert_tokens_to_ids("<|USER|>")
        if isinstance(token_id, int) and token_id != self.tokenizer.unk_token_id:
            return token_id
        return False


    @functools.lru_cache(maxsize=128)
    def _get_assistant_token(self):
        # same lookup as _get_user_token, for the "<|ASSISTANT|>" role token
        token_id = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
        if isinstance(token_id, int) and token_id != self.tokenizer.unk_token_id:
            return token_id
        return False
|
|
    def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        # append the eos token when requested, unless truncation already
        # consumed the full sequence length or the tokenizer added one itself
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # optionally drop the bos token, e.g. for segments that continue an
        # existing sequence rather than starting a new one
        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
| return result |
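
    # Shape of the dict _tokenize returns (illustrative sketch; the actual
    # ids depend on the tokenizer in use):
    #   _tokenize("hello world")
    #   -> {"input_ids":      [bos_id, t1, t2, eos_id],
    #       "attention_mask": [1, 1, 1, 1],
    #       "labels":         [bos_id, t1, t2, eos_id]}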
|
|
|
|
| class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for instruction-based prompts. |
| """ |
|
|
| def parse_instruction_fields( |
| self, prompt |
| ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]: |
| raise NotImplementedError |
|
|
| def tokenize_prompt(self, prompt): |
| ( |
| instruction, |
| input, |
| response, |
| ) = self.parse_instruction_fields(prompt) |
| user_prompt = next( |
| iter( |
| self.prompter.build_prompt( |
| instruction, |
| input, |
| ) |
| ) |
| ) |
| tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False) |
        if not self.train_on_inputs:
            user_prompt_len = len(tokenized_prompt["input_ids"])
            # mask out the prompt tokens so only the response contributes
            # to the loss
            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
        # note: the prompt and the response are truncated independently, so
        # the concatenated sequence can exceed sequence_len
        tokenized_res_prompt = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
|
| return tokenized_prompt |
|
|
| def _build_full_prompt( |
| self, instruction, input, response |
| ): |
| return next( |
| iter( |
| self.prompter.build_prompt( |
| instruction, |
| input, |
| response, |
| ) |
| ) |
| ) |
|
|
|
|
| class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Alpaca prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["instruction"], |
| prompt["input"] if "input" in prompt else "", |
| prompt["output"], |
| ) |
|
|
|
|
| class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Alpaca Multiple Choice prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["question"], |
| "\n".join(f'- "{choice}"' for choice in prompt["choices"]), |
| prompt["solution"] if "solution" in prompt else prompt["explanation"], |
| ) |
|
|
|
|
| class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Jeopardy prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["question"], |
| prompt["category"], |
| "what is " + prompt["answer"], |
| ) |
|
|
|
|
| class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for OpenAssistant prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["INSTRUCTION"], |
| "", |
| prompt["RESPONSE"], |
| ) |
|
|
|
|
| class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for SummarizeTLDR prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["article"], |
| "", |
| prompt["summary"], |
| ) |
|
|
|
|
| class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for GPTeacher prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["instruction"], |
| prompt["input"] if "input" in prompt else "", |
| prompt["response"], |
| ) |
|
|
|
|
| class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for NomicGPT4All prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: |
| return ( |
| prompt["prompt"], |
| "", |
| prompt["response"], |
| ) |
|
|
|
|
| class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Completion prompts. |
| """ |
|
|
| def tokenize_prompt(self, prompt): |
| full_prompt = self._build_full_prompt(prompt["text"], None, None) |
| tokenized_full_prompt = self._tokenize(full_prompt) |
|
|
| return tokenized_full_prompt |
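
    # Completion records only need a "text" field (sketch):
    #   tokenized = strategy.tokenize_prompt({"text": "The quick brown fox..."})
    #   # labels mirror input_ids, so every token is trained on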
|
|
|
|
|
|
| class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Reflection prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: |
| raise NotImplementedError |
|
|
| def tokenize_prompt(self, prompt): |
| ( |
| instruction, |
| input, |
| output, |
| reflection, |
| corrected, |
| ) = self.parse_instruction_fields(prompt) |
| full_prompt = self._build_full_prompt( |
| instruction, input, output, reflection, corrected |
| ) |
| tokenized_full_prompt = self._tokenize(full_prompt) |
| if not self.train_on_inputs: |
| user_prompt = next( |
| iter( |
| self.prompter.build_prompt( |
| instruction, |
| input, |
| ) |
| ) |
| ) |
| tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) |
| user_prompt_len = len(tokenized_user_prompt["input_ids"]) |
| |
| tokenized_full_prompt["labels"] = [ |
| -100 |
| ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] |
|
|
| return tokenized_full_prompt |
|
|
| def _build_full_prompt( |
| self, instruction, input, output, reflection, corrected |
| ): |
| return next( |
| iter( |
| self.prompter.build_prompt( |
| instruction, |
| input, |
| output, |
| reflection, |
| corrected, |
| ) |
| ) |
| ) |
|
|
|
|
|
|
| class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for Alpaca Reflection prompts. |
| """ |
|
|
| def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: |
| return ( |
| prompt["instruction"], |
| prompt["input"] if "input" in prompt else "", |
| prompt["output"], |
| prompt["reflection"], |
| prompt["corrected"], |
| ) |
|
|
|
|
| class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): |
| """ |
| Tokenizing strategy for ShareGPT prompts. |
| """ |
|
|
| def get_conversation_thread(self, prompt): |
| return prompt["conversations"] |
|
|
    def tokenize_prompt(self, prompt):
        result, current_len = tokenize_prompt_default()
        user_token = self._get_user_token()
        assistant_token = self._get_assistant_token()
        try:
            for part in self.prompter.build_prompt(
                self.get_conversation_thread(prompt)
            ):
                if not isinstance(part, tuple):
                    LOG.warning(f"expected a (role, text) tuple, got: {part}")
                    continue
                if part[0] == "USER:":
                    # when a dedicated user token exists, drop the textual
                    # "USER:" prefix and prepend the token id instead
                    part = part[0] + part[1] if not user_token else part[1]
                    res = self._tokenize(
                        part.strip(),
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    if user_token:
                        res["input_ids"] = [user_token, *res["input_ids"]]
                    # user turns are masked out of the loss
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                elif part[0] == "ASSISTANT:":
                    part = part[0] + part[1] if not assistant_token else part[1]
                    # add the eos token so the model learns when to stop
                    res = self._tokenize(
                        part.strip(),
                        add_eos_token=True,
                        strip_bos_token=True,
                    )
                    if assistant_token:
                        res["input_ids"] = [
                            assistant_token,
                            *res["input_ids"],
                        ]
                    # assistant turns are what the model is trained on
                    labels = copy.deepcopy(res["input_ids"])
                elif part[0] == "SYSTEM:":
                    part = part[1]
                    # the system prompt keeps its bos token but is masked
                    # out of the loss
                    res = self._tokenize(
                        part.strip(), add_eos_token=False, strip_bos_token=False
                    )
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                else:
                    LOG.warning(f"unhandled role: {part[0]}")
                    continue

                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err
|
|
|
|
|
|
| def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: |
| """ |
    Returns an empty tokenization result dict and a starting length of zero
| """ |
|
|
| result: Dict[str, List[int]] = { |
| "input_ids": [], |
| "attention_mask": [], |
| "labels": [], |
| } |
| current_len = 0 |
| return result, current_len |
|
|
|
|
| def parse_tokenized_to_result( |
| result: Dict[str, List[int]], |
| current_len: int, |
| res: Dict[str, List[int]], |
| labels: List[int], |
| pad_token_id: Union[int, None] = None, |
| ) -> Tuple[Dict[str, List[int]], int]: |
| """ |
    Parses the tokenized prompt and appends the tokenized input_ids, attention_mask and labels to the result
| """ |
|
|
| input_ids = res["input_ids"] |
| input_len = len(input_ids) |
| result["input_ids"][current_len : current_len + input_len] = input_ids |
| result["attention_mask"][current_len : current_len + input_len] = [ |
| 1 if x != pad_token_id else 0 for x in input_ids |
| ] |
| result["labels"][current_len : current_len + input_len] = labels |
| current_len += input_len |
|
|
| return result, current_len |
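
# Example (sketch): appending one tokenized segment to an empty result
#   result, current_len = tokenize_prompt_default()
#   segment = {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]}
#   result, current_len = parse_tokenized_to_result(
#       result, current_len, segment, labels=[-100, -100, 7]
#   )
#   # result == {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1],
#   #            "labels": [-100, -100, 7]}; current_len == 3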
|
|