Skip to content

Commit

Permalink
onnx export, inference code, need to update serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jayrn2 authored Feb 2, 2024
1 parent a3936ff commit cb7f9bb
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions train_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def main(seed=2022, epoches=50): #500
parser.add_argument('--num_workers', type=int, default=16, metavar='N', help='number of workers for data loader (default: 16)')
parser.add_argument('--loss_name', type=str, default='combo', choices=['weighted_bce', 'dice', 'batch_dice', 'focal','combo','combo_batch', 'combo_mix'], help='set the loss function')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate (default: 1e-3)')
# this argument is so that we run the script once to export the model to onnx
parser.add_argument('--export_onnx', action='store_true', help='Export model as ONNX')
args = parser.parse_args()

# Set random seed
Expand Down Expand Up @@ -79,8 +81,28 @@ def main(seed=2022, epoches=50): #500

# Load pretrained models for training with another dataset
# - path for pretrained set (DTD training set)
pre_trained_model_path = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\log\\ablation_data_loader_mosts_valid_group_3\\epoch_2023_10_19_04_30_27_texture.pth' #the one we wanna use for UC merced set next
pre_trained_model_path = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\log\\ablation_data_loader_mosts_valid_group_3\\epoch_2023_09_29_10_42_09_texture.pth' #the one we wanna use for UC merced set next
model.load_state_dict(torch.load(pre_trained_model_path))
model.load_state_dict(torch.load(pre_trained_model_path))

# Exporting the model as an ONNX file:
if args.export_onnx:
# Settin the model to evaluation mode for export
model.eval()
# This is the directory the ONNX files r saved in
onnx_directory = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\ONNX_exports'
onnx_file_path = os.path.join(onnx_directory, 'model.onnx')
# Create the directory if it doesn't exist
if not os.path.exists(onnx_directory):
os.makedirs(onnx_directory)
# Temporary tensor inputs for 'image' and 'patch' to match model inputs seen in ablation_data_loader.py
dummy_image = torch.randn(1, 3, 256, 256)
dummy_patch = torch.randn(1, 3, 256, 256)
# Actually exporting the model in ONNX format
torch.onnx.export(model, (dummy_image, dummy_patch), onnx_file_path, verbose=True, input_names=['image', 'patch'], output_names=['output'])
# kill the script after exporting
return


# CUDA init
Expand Down

0 comments on commit cb7f9bb

Please sign in to comment.