from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
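
# Disable tokenizers parallelism to avoid fork-related warnings during generation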
os.environ["TOKENIZERS_PARALLELISM"] = "false"

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
# bfloat16 halves memory versus the default fp32 weights; drop torch_dtype if your GPU lacks bf16 support
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", torch_dtype=torch.bfloat16
).to("cuda")

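# A static KV cache preallocates to a fixed size, keeping tensor shapes constant
# so torch.compile can capture the full graph without recompiling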
model.generation_config.cache_implementation = "static"

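# Compile the forward pass; "reduce-overhead" mode records CUDA graphs, so the
# first few generate() calls are noticeably slower (warm-up) before decoding speeds up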
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]

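# apply_chat_template wraps the message in the model's chat format;
# add_generation_prompt=True appends the tokens that open the assistant turn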
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda")

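# Generate up to 256 new tokens and decode the full sequences back to text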
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))