Upload modeling_OneChart.py
modeling_OneChart.py (+26 -12)
CHANGED
@@ -393,8 +393,10 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
     def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
-        device = …
-        …
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        # dtype = torch.bfloat16 if device=="cuda" else next(self.get_model().parameters()).dtype
+        dtype=torch.float16 if device=="cuda" else torch.float32
+        # print(device, dtype)
         def list_json_value(json_dict):
             rst_str = []
             sort_flag = True
@@ -456,17 +458,29 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        …
+        if device=='cuda':
+            with torch.autocast(device, dtype=dtype):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0)],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
+        else:
             output_ids = self.generate(
-            …
+                input_ids,
+                images=[image_tensor_1.unsqueeze(0)],
+                do_sample=False,
+                num_beams = 1,
+                # no_repeat_ngram_size = 20,
+                # streamer=streamer,
+                max_new_tokens=4096,
+                stopping_criteria=[stopping_criteria]
+                )
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
         outputs = outputs.replace("<Number>", "")
         outputs = outputs.strip()
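For context, a hypothetical end-to-end call of the patched chat(). The repo id and image path below are placeholder assumptions, not taken from this commit; the chat() signature itself comes from the diff.

import torch
from transformers import AutoModel, AutoTokenizer

repo = "kppkkp/OneChart"  # assumed checkpoint id, not stated in this commit
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()
if torch.cuda.is_available():
    model = model.cuda()  # chat() now selects a matching device/dtype internally

# Signature from the diff: chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False)
answer = model.chat(tokenizer, image_file="chart.png", reliable_check=True, print_prompt=False)
print(answer)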