| import torch.onnx | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| import onnx | |
| from transformers import AutoModel | |
def export_onnx(example_input: torch.Tensor,
                model: nn.Module,
                onnx_model_name,
                *,
                export_params: bool = True,
                opset_version: int = 10) -> None:
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The graph is traced by running ``model`` on ``example_input``. Axis 0 of
    the graph input ('input') and of the graph output ('output') is declared
    dynamic under the name 'batch_size'.

    Args:
        example_input: Sample tensor fed to ``model`` while tracing.
        model: The module to export.
        onnx_model_name: Destination of the ``.onnx`` file. When it is a
            path (``str`` / ``os.PathLike``), missing parent directories are
            created first; any other value is handed to
            ``torch.onnx.export`` unchanged.
        export_params: Store the model's weights inside the ONNX file.
            Defaults to ``True``. With ``False`` the parameters are not
            embedded and instead become graph inputs, so the file cannot be
            run on its own.
        opset_version: ONNX opset version to target.
    """
    import os

    # torch.onnx.export opens the destination for writing but does not create
    # missing directories, so e.g. "onnx_model/x.onnx" would raise when
    # "onnx_model/" does not exist yet.
    if isinstance(onnx_model_name, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(onnx_model_name))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.onnx.export(
        model,
        example_input,
        onnx_model_name,
        # Previously hard-coded to False, which dropped the weights from
        # the exported file.
        export_params=export_params,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
if __name__ == "__main__":
    # Export each LVM-Med variant in turn. Each entry is:
    # (label used in the printout, Hugging Face Hub repo id,
    #  spatial size H == W of the example input, output ONNX path)
    lvmmed_variants = [
        ("RN50", "ngctnnnn/lvmmed_rn50", 1024, "onnx_model/lvmmed_rn50.onnx"),
        ("ViT", "ngctnnnn/lvmmed_vit", 224, "onnx_model/lvmmed_vit.onnx"),
    ]

    for label, repo_id, image_size, onnx_path in lvmmed_variants:
        example_input = torch.ones(1, 3, image_size, image_size)
        lvmmed_model = AutoModel.from_pretrained(repo_id)

        # Sanity-check forward pass; gradients are not needed here.
        with torch.no_grad():
            example_output = lvmmed_model(example_input)['pooler_output']
        # The label was previously hard-coded to "RN50" for both models.
        print(f"Example output for LVM-Med ({label})'s shape: {example_output.shape}")

        export_onnx(example_input, lvmmed_model, onnx_model_name=onnx_path)