_register_onnx_ops.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import sys
  2. import warnings
  3. import torch
  4. from torch.onnx import symbolic_opset11 as opset11
  5. from torch.onnx.symbolic_helper import parse_args
# ONNX opset versions for which symbolic functions are registered in
# _register_custom_op() below.
_ONNX_OPSET_VERSION_11 = 11
_ONNX_OPSET_VERSION_16 = 16
# Public alias for the lowest opset version used by this module. It is not
# referenced elsewhere in the visible part of this file; presumably it is read
# by importing code -- verify before renaming or removing.
BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11
  9. @parse_args("v", "v", "f")
  10. def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
  11. boxes = opset11.unsqueeze(g, boxes, 0)
  12. scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
  13. max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
  14. iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
  15. # Cast boxes and scores to float32 in case they are float64 inputs
  16. nms_out = g.op(
  17. "NonMaxSuppression",
  18. g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
  19. g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
  20. max_output_per_class,
  21. iou_threshold,
  22. )
  23. return opset11.squeeze(
  24. g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
  25. )
  26. def _process_batch_indices_for_roi_align(g, rois):
  27. indices = opset11.squeeze(
  28. g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
  29. )
  30. return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
  31. def _process_rois_for_roi_align(g, rois):
  32. return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
  33. def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
  34. if sampling_ratio < 0:
  35. warnings.warn(
  36. "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
  37. "The model will be exported with a sampling_ratio of 0."
  38. )
  39. sampling_ratio = 0
  40. return sampling_ratio
  41. @parse_args("v", "v", "f", "i", "i", "i", "i")
  42. def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  43. batch_indices = _process_batch_indices_for_roi_align(g, rois)
  44. rois = _process_rois_for_roi_align(g, rois)
  45. if aligned:
  46. warnings.warn(
  47. "ROIAlign with aligned=True is only supported in opset >= 16. "
  48. "Please export with opset 16 or higher, or use aligned=False."
  49. )
  50. sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
  51. return g.op(
  52. "RoiAlign",
  53. input,
  54. rois,
  55. batch_indices,
  56. spatial_scale_f=spatial_scale,
  57. output_height_i=pooled_height,
  58. output_width_i=pooled_width,
  59. sampling_ratio_i=sampling_ratio,
  60. )
  61. @parse_args("v", "v", "f", "i", "i", "i", "i")
  62. def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  63. batch_indices = _process_batch_indices_for_roi_align(g, rois)
  64. rois = _process_rois_for_roi_align(g, rois)
  65. coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
  66. sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
  67. return g.op(
  68. "RoiAlign",
  69. input,
  70. rois,
  71. batch_indices,
  72. coordinate_transformation_mode_s=coordinate_transformation_mode,
  73. spatial_scale_f=spatial_scale,
  74. output_height_i=pooled_height,
  75. output_width_i=pooled_width,
  76. sampling_ratio_i=sampling_ratio,
  77. )
  78. @parse_args("v", "v", "f", "i", "i")
  79. def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
  80. roi_pool = g.op(
  81. "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
  82. )
  83. return roi_pool, None
  84. def _register_custom_op():
  85. torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
  86. torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
  87. torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
  88. torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)