_meta_registrations.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import functools
  2. import torch
  3. import torch._custom_ops
  4. import torch.library
  5. # Ensure that torch.ops.torchvision is visible
  6. import torchvision.extension # noqa: F401
  7. @functools.lru_cache(None)
  8. def get_meta_lib():
  9. return torch.library.Library("torchvision", "IMPL", "Meta")
  10. def register_meta(op_name, overload_name="default"):
  11. def wrapper(fn):
  12. if torchvision.extension._has_ops():
  13. get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
  14. return fn
  15. return wrapper
  16. @register_meta("roi_align")
  17. def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  18. torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
  19. torch._check(
  20. input.dtype == rois.dtype,
  21. lambda: (
  22. "Expected tensor for input to have the same type as tensor for rois; "
  23. f"but type {input.dtype} does not equal {rois.dtype}"
  24. ),
  25. )
  26. num_rois = rois.size(0)
  27. channels = input.size(1)
  28. return input.new_empty((num_rois, channels, pooled_height, pooled_width))
  29. @register_meta("_roi_align_backward")
  30. def meta_roi_align_backward(
  31. grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
  32. ):
  33. torch._check(
  34. grad.dtype == rois.dtype,
  35. lambda: (
  36. "Expected tensor for grad to have the same type as tensor for rois; "
  37. f"but type {grad.dtype} does not equal {rois.dtype}"
  38. ),
  39. )
  40. return grad.new_empty((batch_size, channels, height, width))
  41. @register_meta("ps_roi_align")
  42. def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
  43. torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
  44. torch._check(
  45. input.dtype == rois.dtype,
  46. lambda: (
  47. "Expected tensor for input to have the same type as tensor for rois; "
  48. f"but type {input.dtype} does not equal {rois.dtype}"
  49. ),
  50. )
  51. channels = input.size(1)
  52. torch._check(
  53. channels % (pooled_height * pooled_width) == 0,
  54. "input channels must be a multiple of pooling height * pooling width",
  55. )
  56. num_rois = rois.size(0)
  57. out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
  58. return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
  59. @register_meta("_ps_roi_align_backward")
  60. def meta_ps_roi_align_backward(
  61. grad,
  62. rois,
  63. channel_mapping,
  64. spatial_scale,
  65. pooled_height,
  66. pooled_width,
  67. sampling_ratio,
  68. batch_size,
  69. channels,
  70. height,
  71. width,
  72. ):
  73. torch._check(
  74. grad.dtype == rois.dtype,
  75. lambda: (
  76. "Expected tensor for grad to have the same type as tensor for rois; "
  77. f"but type {grad.dtype} does not equal {rois.dtype}"
  78. ),
  79. )
  80. return grad.new_empty((batch_size, channels, height, width))
  81. @register_meta("roi_pool")
  82. def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
  83. torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
  84. torch._check(
  85. input.dtype == rois.dtype,
  86. lambda: (
  87. "Expected tensor for input to have the same type as tensor for rois; "
  88. f"but type {input.dtype} does not equal {rois.dtype}"
  89. ),
  90. )
  91. num_rois = rois.size(0)
  92. channels = input.size(1)
  93. out_size = (num_rois, channels, pooled_height, pooled_width)
  94. return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
  95. @register_meta("_roi_pool_backward")
  96. def meta_roi_pool_backward(
  97. grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
  98. ):
  99. torch._check(
  100. grad.dtype == rois.dtype,
  101. lambda: (
  102. "Expected tensor for grad to have the same type as tensor for rois; "
  103. f"but type {grad.dtype} does not equal {rois.dtype}"
  104. ),
  105. )
  106. return grad.new_empty((batch_size, channels, height, width))
  107. @register_meta("ps_roi_pool")
  108. def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
  109. torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
  110. torch._check(
  111. input.dtype == rois.dtype,
  112. lambda: (
  113. "Expected tensor for input to have the same type as tensor for rois; "
  114. f"but type {input.dtype} does not equal {rois.dtype}"
  115. ),
  116. )
  117. channels = input.size(1)
  118. torch._check(
  119. channels % (pooled_height * pooled_width) == 0,
  120. "input channels must be a multiple of pooling height * pooling width",
  121. )
  122. num_rois = rois.size(0)
  123. out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
  124. return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
  125. @register_meta("_ps_roi_pool_backward")
  126. def meta_ps_roi_pool_backward(
  127. grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
  128. ):
  129. torch._check(
  130. grad.dtype == rois.dtype,
  131. lambda: (
  132. "Expected tensor for grad to have the same type as tensor for rois; "
  133. f"but type {grad.dtype} does not equal {rois.dtype}"
  134. ),
  135. )
  136. return grad.new_empty((batch_size, channels, height, width))
@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
    """Abstract (meta) kernel for ``torchvision::nms``.

    Validates that ``dets`` is a 2-D ``[N, 4]`` tensor, ``scores`` is 1-D, and
    both agree on ``N``, then returns an empty 1-D ``int64`` tensor of kept
    indices. Its length is an unbacked SymInt because the number of surviving
    boxes is data-dependent and cannot be known at trace time.
    """
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
    torch._check(
        dets.size(0) == scores.size(0),
        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
    )
    # Ask the abstract-impl context for a fresh symbolic integer to stand in
    # for the unknown output length.
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)
  149. @register_meta("deform_conv2d")
  150. def meta_deform_conv2d(
  151. input,
  152. weight,
  153. offset,
  154. mask,
  155. bias,
  156. stride_h,
  157. stride_w,
  158. pad_h,
  159. pad_w,
  160. dil_h,
  161. dil_w,
  162. n_weight_grps,
  163. n_offset_grps,
  164. use_mask,
  165. ):
  166. out_height, out_width = offset.shape[-2:]
  167. out_channels = weight.shape[0]
  168. batch_size = input.shape[0]
  169. return input.new_empty((batch_size, out_channels, out_height, out_width))
  170. @register_meta("_deform_conv2d_backward")
  171. def meta_deform_conv2d_backward(
  172. grad,
  173. input,
  174. weight,
  175. offset,
  176. mask,
  177. bias,
  178. stride_h,
  179. stride_w,
  180. pad_h,
  181. pad_w,
  182. dilation_h,
  183. dilation_w,
  184. groups,
  185. offset_groups,
  186. use_mask,
  187. ):
  188. grad_input = input.new_empty(input.shape)
  189. grad_weight = weight.new_empty(weight.shape)
  190. grad_offset = offset.new_empty(offset.shape)
  191. grad_mask = mask.new_empty(mask.shape)
  192. grad_bias = bias.new_empty(bias.shape)
  193. return grad_input, grad_weight, grad_offset, grad_mask, grad_bias