# disco_xform_utils.py

import torch, torchvision
import py3d_tools as p3d
import midas_utils
from PIL import Image
import numpy as np
import sys, math

try:
    from infer import InferenceHelper
except:
    print("disco_xform_utils.py failed to import InferenceHelper. Please ensure that the AdaBins directory is on the path (e.g. via sys.path.append('./AdaBins')).")
    sys.exit()

MAX_ADABINS_AREA = 500000
MIN_ADABINS_AREA = 448*448

@torch.no_grad()
def transform_image_3d(img_filepath, midas_model, midas_transform, device, rot_mat=torch.eye(3).unsqueeze(0), translate=(0., 0., -0.04), near=2000, far=20000, fov_deg=60, padding_mode='border', sampling_mode='bicubic', midas_weight=0.3, spherical=False):
    img_pil = Image.open(open(img_filepath, 'rb')).convert('RGB')
    w, h = img_pil.size
    image_tensor = torchvision.transforms.functional.to_tensor(img_pil).to(device)

    use_adabins = midas_weight < 1.0

    if use_adabins:
        # AdaBins depth estimation (model trained on the NYU dataset).
        print("Running AdaBins depth estimation implementation...")
        infer_helper = InferenceHelper(dataset='nyu')

        image_pil_area = w*h
        if image_pil_area > MAX_ADABINS_AREA:
            scale = math.sqrt(MAX_ADABINS_AREA) / math.sqrt(image_pil_area)
            depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.LANCZOS)  # LANCZOS is supposed to be good for downsampling.
        elif image_pil_area < MIN_ADABINS_AREA:
            scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area)
            depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.BICUBIC)
        else:
            depth_input = img_pil
        try:
            _, adabins_depth = infer_helper.predict_pil(depth_input)
            if image_pil_area != MAX_ADABINS_AREA:
                adabins_depth = torchvision.transforms.functional.resize(torch.from_numpy(adabins_depth), image_tensor.shape[-2:], interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC).squeeze().to(device)
            else:
                adabins_depth = torch.from_numpy(adabins_depth).squeeze().to(device)
            adabins_depth_np = adabins_depth.cpu().numpy()
        except:
            # If AdaBins inference fails, fall back to MiDaS-only depth so the
            # blend below does not reference an undefined adabins_depth_np.
            print("AdaBins depth estimation failed; falling back to MiDaS only.")
            use_adabins = False

    torch.cuda.empty_cache()

    # MiDaS
    img_midas = midas_utils.read_image(img_filepath)
    img_midas_input = midas_transform({"image": img_midas})["image"]
    midas_optimize = True

    # MiDaS depth estimation implementation
    print("Running MiDaS depth estimation implementation...")
    sample = torch.from_numpy(img_midas_input).float().to(device).unsqueeze(0)
    if midas_optimize and device == torch.device("cuda"):
        sample = sample.to(memory_format=torch.channels_last)
        sample = sample.half()
    prediction_torch = midas_model.forward(sample)
    prediction_torch = torch.nn.functional.interpolate(
        prediction_torch.unsqueeze(1),
        size=img_midas.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    prediction_np = prediction_torch.clone().cpu().numpy()

    print("Finished depth estimation.")
    torch.cuda.empty_cache()

    # MiDaS makes the near values greater, and the far values lesser. Let's reverse that and try to align with AdaBins a bit better.
    prediction_np = np.subtract(50.0, prediction_np)
    prediction_np = prediction_np / 19.0

    if use_adabins:
        adabins_weight = 1.0 - midas_weight
        depth_map = prediction_np*midas_weight + adabins_depth_np*adabins_weight
    else:
        depth_map = prediction_np

    depth_map = np.expand_dims(depth_map, axis=0)
    depth_tensor = torch.from_numpy(depth_map).squeeze().to(device)

    pixel_aspect = 1.0  # aspect ratio of an individual pixel (usually 1.0)
    persp_cam_old = p3d.FoVPerspectiveCameras(near, far, pixel_aspect, fov=fov_deg, degrees=True, device=device)
    persp_cam_new = p3d.FoVPerspectiveCameras(near, far, pixel_aspect, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)

    # The [-1, 1] range is important for torch grid_sample's padding handling.
    y, x = torch.meshgrid(torch.linspace(-1., 1., h, dtype=torch.float32, device=device), torch.linspace(-1., 1., w, dtype=torch.float32, device=device))
    z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)
    xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)

    # Project the points with pytorch3d. For what is done here this is overkill and prevents the code
    # from working on Windows. If you need it to run without pytorch3d, the camera transform and
    # perspective projection can be done by hand (see the _project_points_pinhole sketch after this function).
    xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:, 0:2]
    xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:, 0:2]

    offset_xy = xyz_new_cam_xy - xyz_old_cam_xy
    # affine_grid's theta param expects a batch of 2D matrices; each is 2x3 and encodes a rotation + translation.
    identity_2d_batch = torch.tensor([[1., 0., 0.], [0., 1., 0.]], device=device).unsqueeze(0)
    # coords_2d has shape (N, H, W, 2), which is also what grid_sample expects.
    coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1, 1, h, w], align_corners=False)
    offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h, w, 2)).unsqueeze(0)

    if spherical:
        spherical_grid = get_spherical_projection(h, w, torch.tensor([0, 0], device=device), -0.4, device=device)
        stage_image = torch.nn.functional.grid_sample(image_tensor.add(1/512 - 0.0001).unsqueeze(0), offset_coords_2d, mode=sampling_mode, padding_mode=padding_mode, align_corners=True)
        new_image = torch.nn.functional.grid_sample(stage_image, spherical_grid, align_corners=True)
    else:
        new_image = torch.nn.functional.grid_sample(image_tensor.add(1/512 - 0.0001).unsqueeze(0), offset_coords_2d, mode=sampling_mode, padding_mode=padding_mode, align_corners=False)

    img_pil = torchvision.transforms.ToPILImage()(new_image.squeeze().clamp(0, 1.))

    torch.cuda.empty_cache()

    return img_pil
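
# Sketch referenced by the comment in transform_image_3d above: the two pytorch3d
# camera projections could in principle be replaced by a plain pinhole projection.
# This helper is an illustration only, added as an assumption of how one might start.
# It follows pytorch3d's row-vector convention (x_cam = x_world @ R + T) and a
# symmetric frustum, but it has NOT been verified to match FoVPerspectiveCameras'
# NDC conventions exactly, so do not treat it as a drop-in replacement.
def _project_points_pinhole(xyz_world, rot_mat, translate, fov_deg, device):
    R = rot_mat.squeeze(0).to(device)                                 # (3, 3) rotation
    T = torch.tensor(translate, dtype=torch.float32, device=device)   # (3,) translation
    xyz_cam = xyz_world @ R + T                                       # world -> camera coords, (N, 3)
    focal = 1.0 / math.tan(math.radians(fov_deg) / 2.0)               # focal length from the field of view
    z = xyz_cam[:, 2:3].clamp(min=1e-6)                               # guard against division by zero
    return focal * xyz_cam[:, 0:2] / z                                # perspective divide -> (N, 2) screen xy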

def get_spherical_projection(H, W, center, magnitude, device):
    # Build a (1, H, W, 2) grid_sample grid that displaces coordinates toward (or away from)
    # `center` in proportion to their distance from it, giving a spherical-lens style
    # distortion whose strength and direction are controlled by `magnitude`.
    xx, yy = torch.linspace(-1, 1, W, dtype=torch.float32, device=device), torch.linspace(-1, 1, H, dtype=torch.float32, device=device)
    gridy, gridx = torch.meshgrid(yy, xx)
    grid = torch.stack([gridx, gridy], dim=-1)
    d = center - grid
    d_sum = torch.sqrt((d**2).sum(axis=-1))
    grid += d * d_sum.unsqueeze(-1) * magnitude
    return grid.unsqueeze(0)
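
# --- Usage sketch (illustration only, not part of the original module) ---
# A minimal example of how transform_image_3d might be driven for one animation step.
# It assumes a MiDaS model and its dict-based preprocessing transform have already been
# loaded by the caller; `load_midas_model` below is a hypothetical placeholder for that
# loading code, not a function provided by this file, and the frame filenames are made up.
#
# if __name__ == "__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     midas_model, midas_transform = load_midas_model(device)  # hypothetical loader
#     rot_mat = torch.eye(3).unsqueeze(0)      # no rotation between frames
#     translate = (0., 0., -0.05)              # small dolly toward the scene
#     out = transform_image_3d("frame_0000.png", midas_model, midas_transform, device,
#                              rot_mat=rot_mat, translate=translate,
#                              midas_weight=0.3, spherical=False)
#     out.save("frame_0001.png")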