uci-uav-forge / uavf_2024

MIT License

Vectorize particle filter #187

Closed by EricPedley 1 month ago

EricPedley commented 2 months ago

Summary

Speeds up the particle filter by around 4x. More performance optimizations are probably possible, but they'd be hard. The slowest part right now is the measurement function, which has to find vectors on the edges of the bounding spheres and run camera projection on them. The camera projection is ~20% of that function's runtime, and I think we could speed it up by using quaternions for the camera rotation instead of rotation matrices. Also, the code is somehow slower on the GPU than on the CPU, even though the slowest parts are PyTorch operations. That suggests something is wrong, and it's worth debugging how the data is being moved to the GPU and operated on.
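
The quaternion idea can be sketched roughly as follows. This is a minimal NumPy illustration, not code from this PR: `quat_rotate` and `quat_to_matrix` are hypothetical helper names, and the quaternion is assumed to be a unit quaternion in (x, y, z, w) order. If `cam_pose[1]` is a SciPy `Rotation`, as the `as_matrix()` calls suggest, `as_quat()` returns that ordering by default. Whether this actually beats a matrix multiply would need profiling.

```python
import numpy as np

def quat_rotate(q, pts):
    """Rotate points `pts` (n, 3) by the unit quaternion `q` = (x, y, z, w)."""
    q = np.asarray(q, dtype=np.float64)
    pts = np.asarray(pts, dtype=np.float64)
    u, w = q[:3], q[3]
    # v' = v + 2w (u x v) + 2 u x (u x v), written with t = 2 (u x v)
    t = 2.0 * np.cross(u, pts)
    return pts + w * t + np.cross(u, t)

def quat_to_matrix(q):
    """Reference 3x3 rotation matrix for the same unit quaternion."""
    x, y, z, w = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])

# 90 degree rotation about the z axis, applied to a small batch of points
q = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])
pts = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.3, -0.2, 0.9]])
rotated = quat_rotate(q, pts)
```

The two cross products replace the 3x3 matrix product per point, and the result should match `pts @ quat_to_matrix(q).T` up to floating-point error.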

EricPedley commented 1 month ago

Benchmarking results from the orin:

forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py 
20it [01:16,  3.81s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s

Total time: 65.4881 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   117                                               @profile
   118                                               def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
   119                                                   '''
   120                                                   `states` is (n, 7)
   121                                                   returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
   122                                                   '''
   123     23260     100939.8      4.3      0.2          cam = CameraModel(self.focal_len_pixels, 
   124     11630      27494.0      2.4      0.0                          [self.resolution[0]/2, self.resolution[1]/2], 
   125     11630     228127.5     19.6      0.3                          cam_pose[1].as_matrix(), 
   126     11630      50272.7      4.3      0.1                          cam_pose[0].reshape(3,1))
   127                                           
   128     11630      44174.0      3.8      0.1          n = states.shape[0]
   129     11630     381818.1     32.8      0.6          positions = states[:, :3]
   130     11630     317663.1     27.3      0.5          radii = states[:, -1]
   131                                           
   132     11630    1312691.4    112.9      2.0          cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
   133     11630     594472.2     51.1      0.9          rays_to_center = positions - cam_position_tensor
   134                                                   
   135     11630      12479.0      1.1      0.0          n_samples = 25
   136     11630   37973416.9   3265.1     58.0          orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
   137                                           
   138     11630    1852451.1    159.3      2.8          pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
   139                                           
   140     11630     296315.5     25.5      0.5          pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
   141                                           
   142                                                   # project points into the camera
   143     11630   16447452.0   1414.2     25.1          projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
   144     11630    1283011.3    110.3      2.0          x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
   145     11630    1186818.7    102.0      1.8          y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
   146     11630    1225458.6    105.4      1.9          x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
   147     11630    1167045.2    100.3      1.8          y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
   148                                           
   149     11630     985989.0     84.8      1.5          return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)

Total time: 14.5859 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: update at line 151

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   151                                               @profile
   152                                               def update(self, cam_pose: tuple[np.ndarray, Rotation], measurement:Measurement):
   153                                                   '''
   154                                                   measurements is a list of 2D integer bounding boxes in pixel coordinates (x1,y1,x2,y2)
   155                                                   '''
   156                                           
   157                                                   # add particles to `samples` that would line up with the measurements
   158                                                   # self.samples.extend(
   159                                                   #     self.gen_samples_from_measurement(cam_pose, measurement.box, 10)
   160                                                   #     )
   161                                           
   162        19   14454575.8 760767.1     99.1          measurements = self.compute_measurements(cam_pose, self.samples)
   163        19      87296.3   4594.5      0.6          likelihoods = self.compute_likelihoods(measurements, measurement.box.to_xyxy())
   164                                           
   165                                                   # resample the particles
   166        19      44032.7   2317.5      0.3          self.resample(likelihoods)
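
In the profile above, the four separate `torch.min`/`torch.max` reductions at the end of `compute_measurements` add up to about 7.5% of its time. One thing that might be worth trying is reducing over the sample axis once for the minimum and once for the maximum, then concatenating. The sketch below uses NumPy so it runs without PyTorch; `boxes_from_projected` is a hypothetical name, and the torch equivalents would be `torch.amin`/`torch.amax` with `dim=1`. It has not been benchmarked on the Orin.

```python
import numpy as np

def boxes_from_projected(projected_points):
    """(n, n_samples, 2) pixel coordinates -> (n, 4) boxes as [x1, y1, x2, y2]."""
    mins = projected_points.min(axis=1)  # (n, 2): x_min, y_min per particle
    maxs = projected_points.max(axis=1)  # (n, 2): x_max, y_max per particle
    return np.concatenate([mins, maxs], axis=1)

# Example with random projected points: 5 particles, 25 samples each
rng = np.random.default_rng(0)
projected = rng.uniform(0, 1000, size=(5, 25, 2))
boxes = boxes_from_projected(projected)
```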

EricPedley commented 1 month ago

Baseline benchmarking without these changes (but still with the plotting part commented-out):

forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py 
20it [00:44,  2.24s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s

Total time: 26.6667 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurement at line 121

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   121                                               @profile
   122                                               def compute_measurement(self, cam_pose: tuple[np.ndarray, Rotation], state: np.ndarray) -> BoundingBox:
   123                                                   '''
   124                                                   `state` is the state of the track, which is a 7 element array
   125                                                   '''
   126                                                   # if behind the camera, return a box with 0 area
   127     73248    3959894.6     54.1     14.8          if np.dot(state[:3] - cam_pose[0], cam_pose[1].apply([0,0,1])) < 0:
   128         1          4.4      4.4      0.0              return BoundingBox(0, 0, 0, 0)
   129                                           
   130    146494     383314.2      2.6      1.4          cam = CameraModel(self.focal_len_pixels, 
   131     73247     101525.0      1.4      0.4                          [self.resolution[0]/2, self.resolution[1]/2], 
   132     73247     472996.9      6.5      1.8                          cam_pose[1].as_matrix(), 
   133     73247     180592.5      2.5      0.7                          cam_pose[0].reshape(3,1))
   134                                           
   135     73247      90192.4      1.2      0.3          state_position = state[:3]
   136     73247      67216.9      0.9      0.3          state_radius = state[-1]
   137                                           
   138     73247     262011.1      3.6      1.0          ray_to_center = state_position - cam_pose[0] 
   139                                                   
   140                                                   # monte carlo to find the circumscribed rectangle around the sphere's projection into the camera
   141                                                   # there's probably a better way to do this but I'm not sure what it is
   142                                                   # I tried a method where we project 4 points on the boundary and fit a 2d ellipse to their projection
   143                                                   # but the ellipse fitting was not working well
   144     73247      29165.9      0.4      0.1          n_samples = 100
   145                                           
   146                                                   # sample points on the sphere
   147     73247    1464999.2     20.0      5.5          random_vector = np.random.randn(3, n_samples)
   148     73247    4822581.3     65.8     18.1          random_vector -= np.dot(random_vector.T, ray_to_center) * np.repeat([ray_to_center / np.linalg.norm(ray_to_center)], n_samples, axis=0).T
   149     73247    2824607.9     38.6     10.6          random_vector = random_vector / np.linalg.norm(random_vector, axis=0) * state_radius
   150                                           
   151                                                   # project points into the camera
   152     73247    7656917.1    104.5     28.7          projected_points = cam.project(state_position.reshape((3,1)) + random_vector)
   153     73247    1395048.9     19.0      5.2          x_min = np.min(projected_points[0])
   154     73247     951478.9     13.0      3.6          x_max = np.max(projected_points[0])
   155     73247     807701.3     11.0      3.0          y_min = np.min(projected_points[1])
   156     73247     776503.6     10.6      2.9          y_max = np.max(projected_points[1])
   157                                           
   158     73247     419950.5      5.7      1.6          return BoundingBox((x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min)

Total time: 22.3501 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: update at line 160

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   160                                               @profile
   161                                               def update(self, cam_pose: tuple[np.ndarray, Rotation], measurement:Measurement):
   162                                                   '''
   163                                                   measurements is a list of 2D integer bounding boxes in pixel coordinates (x1,y1,x2,y2)
   164                                                   '''
   165                                           
   166                                                   # add particles to `samples` that would line up with the measurements
   167                                                   # self.samples.extend(
   168                                                   #     self.gen_samples_from_measurement(cam_pose, measurement.box, 10)
   169                                                   #     )
   170                                           
   171     19019      27968.6      1.5      0.1          for i, particle in enumerate(self.samples):
   172                                           
   173                                                       # compute the likelihood of the particle given the measurement
   174                                                       # by comparing the measurement to the particle's predicted
   175                                                       # measurement
   176     19000    8726033.1    459.3     39.0              predicted_measurement = self.compute_measurement(cam_pose, particle.state).to_xyxy()
   177     19000   12061590.6    634.8     54.0              particle.likelihood = self.compute_likelihood(predicted_measurement, measurement.box.to_xyxy())
   178                                           
   179                                                   # resample the particles
   180        19    1534535.3  80765.0      6.9          self.resample()

EricPedley commented 1 month ago

It seems the slowest parts are computing the vectors orthogonal to the camera look vectors (to get the sphere boundaries) and projecting them into the camera. Also, fast inverse square root isn't actually faster than 1/torch.sqrt(x) :(
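
A possible explanation, offered as an assumption rather than something established in this thread: the bit trick expands into several elementwise tensor operations (a multiply, a reinterpret, a shift and subtract, then a Newton step), and each one makes its own pass over the data, whereas `1/torch.sqrt(x)` is only two operations. The NumPy sketch below mirrors the same steps to make the operation count visible; `fast_inv_sqrt_np` is a hypothetical name, and the input is assumed to be positive float32 values.

```python
import numpy as np

def fast_inv_sqrt_np(x):
    """Approximate 1/sqrt(x) for positive float32 values using the classic bit trick."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    i = x.view(np.int32)                     # reinterpret the float bits as integers
    i = np.int32(0x5F3759DF) - (i >> 1)      # magic-constant initial guess
    y = i.view(np.float32)                   # reinterpret back to float
    # one Newton-Raphson refinement step
    return y * (np.float32(1.5) - np.float32(0.5) * x * y * y)

x = np.array([0.25, 1.0, 2.0, 9.0, 1234.5], dtype=np.float32)
approx = fast_inv_sqrt_np(x)
exact = 1.0 / np.sqrt(x)
```

With a single Newton step the result is only approximate (relative error on the order of 0.1%). PyTorch also provides `torch.rsqrt`, which computes the reciprocal square root in one call; it was not benchmarked in this thread.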

Without fast inv sqrt:

forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py 
20it [01:01,  3.09s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s

Total time: 50.4935 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   117                                               @profile
   118                                               def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
   119                                                   '''
   120                                                   `states` is (n, 7)
   121                                                   returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
   122                                                   '''
   123     23452     106006.8      4.5      0.2          cam = CameraModel(self.focal_len_pixels, 
   124     11726      27706.6      2.4      0.1                          [self.resolution[0]/2, self.resolution[1]/2], 
   125     11726     228439.9     19.5      0.5                          cam_pose[1].as_matrix(), 
   126     11726      51397.0      4.4      0.1                          cam_pose[0].reshape(3,1))
   127                                           
   128     11726      43833.8      3.7      0.1          n = states.shape[0]
   129     11726     417475.7     35.6      0.8          positions = states[:, :3]
   130     11726     347961.7     29.7      0.7          radii = states[:, -1]
   131                                           
   132     11726    1407496.1    120.0      2.8          cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
   133     11726     631097.3     53.8      1.2          rays_to_center = positions - cam_position_tensor
   134                                                   
   135     11726      12367.5      1.1      0.0          n_samples = 25
   136     11726   21668508.7   1847.9     42.9          orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
   137                                           
   138     11726    1967327.5    167.8      3.9          pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
   139                                           
   140     11726     284906.5     24.3      0.6          pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
   141                                           
   142                                                   # project points into the camera
   143     11726   17022874.9   1451.7     33.7          projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
   144     11726    1355292.4    115.6      2.7          x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
   145     11726    1301534.9    111.0      2.6          y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
   146     11726    1335170.3    113.9      2.6          x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
   147     11726    1268481.5    108.2      2.5          y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
   148                                           
   149     11726    1015606.3     86.6      2.0          return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)

Total time: 20.6904 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   109                                           @profile
   110                                           def make_ortho_vectors(v: torch.Tensor, m: int):
   111                                               '''
   112                                               `v` is a (n,3) tensor
   113                                               make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
   114                                               
   115                                               to visualize: imagine each v_i is the vector coinciding 
   116                                               with a lion's face direction, and we wish to make m vectors for the lion's mane.
   117                                           
   118                                               it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
   119                                               with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
   120                                               that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
   121                                           
   122                                               returns a tensor of shape (n,m,3)
   123                                               '''
   124     11726      46500.6      4.0      0.2      n = v.shape[0]
   125     11726    1202476.2    102.5      5.8      thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
   126                                           
   127     11726    1352173.7    115.3      6.5      phi_y = torch.atan2(v[:, 0], v[:, 2])
   128     11726    2734895.5    233.2     13.2      square_sum = v[:,0]**2 + v[:,2]**2
   129     11726    1800831.5    153.6      8.7      inverted = 1/torch.sqrt(square_sum)#fast_inv_sqrt(square_sum)
   130     11726    1439063.6    122.7      7.0      phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
   131                                           
   132     11726     487013.4     41.5      2.4      cos_y = torch.cos(phi_y)
   133     11726     485644.9     41.4      2.3      sin_y = torch.sin(phi_y)
   134     11726     453987.2     38.7      2.2      cos_x = torch.cos(phi_x)
   135     11726     447829.4     38.2      2.2      sin_x = torch.sin(phi_x)
   136                                           
   137                                           
   138     35178    1034150.2     29.4      5.0      R = torch.stack(
   139     23452    1423668.2     60.7      6.9              [cos_y, -sin_y*sin_x, sin_y*cos_x,
   140     11726     575882.8     49.1      2.8              torch.zeros_like(cos_x), cos_x, sin_x,
   141     11726    1732036.2    147.7      8.4              -sin_y, -cos_y*sin_x, cos_y*cos_x]
   142     11726     337977.1     28.8      1.6      ).T.reshape(n,3,3)
   143                                               # (n,3,3)
   144                                           
   145                                           
   146     23452     789558.7     33.7      3.8      vectors = torch.stack(
   147     11726      15118.9      1.3      0.1          [
   148     11726     504119.3     43.0      2.4              torch.cos(thetas), 
   149     11726     466240.2     39.8      2.3              torch.sin(thetas), 
   150     11726     546511.6     46.6      2.6              torch.zeros_like(thetas)
   151                                                   ],
   152                                               ) # (3,m)
   153                                           
   154     11726    2814702.0    240.0     13.6      return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)

With fast inv sqrt:

forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py 
20it [01:05,  3.28s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s

Total time: 54.3264 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   117                                               @profile
   118                                               def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
   119                                                   '''
   120                                                   `states` is (n, 7)
   121                                                   returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
   122                                                   '''
   123     23416     102576.1      4.4      0.2          cam = CameraModel(self.focal_len_pixels, 
   124     11708      26484.2      2.3      0.0                          [self.resolution[0]/2, self.resolution[1]/2], 
   125     11708     228881.3     19.5      0.4                          cam_pose[1].as_matrix(), 
   126     11708      52137.3      4.5      0.1                          cam_pose[0].reshape(3,1))
   127                                           
   128     11708      44564.4      3.8      0.1          n = states.shape[0]
   129     11708     418965.1     35.8      0.8          positions = states[:, :3]
   130     11708     354484.2     30.3      0.7          radii = states[:, -1]
   131                                           
   132     11708    1411121.9    120.5      2.6          cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
   133     11708     634851.9     54.2      1.2          rays_to_center = positions - cam_position_tensor
   134                                                   
   135     11708      11406.4      1.0      0.0          n_samples = 25
   136     11708   25274502.8   2158.7     46.5          orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
   137                                           
   138     11708    1994273.8    170.3      3.7          pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
   139                                           
   140     11708     287269.6     24.5      0.5          pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
   141                                           
   142                                                   # project points into the camera
   143     11708   17174573.5   1466.9     31.6          projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
   144     11708    1356988.1    115.9      2.5          x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
   145     11708    1302896.4    111.3      2.4          y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
   146     11708    1346783.8    115.0      2.5          x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
   147     11708    1274790.4    108.9      2.3          y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
   148                                           
   149     11708    1028859.9     87.9      1.9          return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)

Total time: 4.82129 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: fast_inv_sqrt at line 92

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    92                                           @profile
    93                                           def fast_inv_sqrt(x: torch.Tensor):
    94                                               '''
    95                                               Calculates 1/sqrt(x) really fast.
    96                                               If x is (n,) this will be vectorized too
    97                                           
    98                                               '''
    99     11708      11281.0      1.0      0.2      three_halfs = 1.5
   100     11708     668128.5     57.1     13.9      x2 = x * 0.5
   101     11708      11269.4      1.0      0.2      y = x
   102     11708     161805.5     13.8      3.4      i = x.view(torch.int32)
   103     11708    1532964.6    130.9     31.8      i = 0x5f3759df - (i>>1)
   104     11708     132011.4     11.3      2.7      y = i.view(torch.float32)
   105     11708    2292938.0    195.8     47.6      y = y * (three_halfs - (x2 * y * y))
   106                                           
   107     11708      10894.5      0.9      0.2      return y

Total time: 24.3137 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   109                                           @profile
   110                                           def make_ortho_vectors(v: torch.Tensor, m: int):
   111                                               '''
   112                                               `v` is a (n,3) tensor
   113                                               make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
   114                                               
   115                                               to visualize: imagine each v_i is the vector coinciding 
   116                                               with a lion's face direction, and we wish to make m vectors for the lion's mane.
   117                                           
   118                                               it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
   119                                               with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
   120                                               that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
   121                                           
   122                                               returns a tensor of shape (n,m,3)
   123                                               '''
   124     11708      47894.3      4.1      0.2      n = v.shape[0]
   125     11708    1210825.1    103.4      5.0      thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
   126                                           
   127     11708    1368071.0    116.8      5.6      phi_y = torch.atan2(v[:, 0], v[:, 2])
   128     11708    2776173.9    237.1     11.4      square_sum = v[:,0]**2 + v[:,2]**2
   129     11708    5167081.1    441.3     21.3      inverted = fast_inv_sqrt(square_sum)
   130     11708    1467019.2    125.3      6.0      phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
   131                                           
   132     11708     490739.4     41.9      2.0      cos_y = torch.cos(phi_y)
   133     11708     476400.1     40.7      2.0      sin_y = torch.sin(phi_y)
   134     11708     461560.0     39.4      1.9      cos_x = torch.cos(phi_x)
   135     11708     450491.7     38.5      1.9      sin_x = torch.sin(phi_x)
   136                                           
   137                                           
   138     35124    1041920.8     29.7      4.3      R = torch.stack(
   139     23416    1458563.3     62.3      6.0              [cos_y, -sin_y*sin_x, sin_y*cos_x,
   140     11708     591757.9     50.5      2.4              torch.zeros_like(cos_x), cos_x, sin_x,
   141     11708    1800907.1    153.8      7.4              -sin_y, -cos_y*sin_x, cos_y*cos_x]
   142     11708     339204.8     29.0      1.4      ).T.reshape(n,3,3)
   143                                               # (n,3,3)
   144                                           
   145                                           
   146     23416     800002.8     34.2      3.3      vectors = torch.stack(
   147     11708      13927.3      1.2      0.1          [
   148     11708     506071.3     43.2      2.1              torch.cos(thetas), 
   149     11708     470040.7     40.1      1.9              torch.sin(thetas), 
   150     11708     546612.1     46.7      2.2              torch.zeros_like(thetas)
   151                                                   ],
   152                                               ) # (3,m)
   153                                           
   154     11708    2828458.1    241.6     11.6      return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)
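
Since `make_ortho_vectors` dominates the profile, here is one alternative construction that might be worth comparing. It is a sketch only, not the method used in this PR: instead of building a per-particle rotation matrix from `atan2`/`atan`/`sin`/`cos`, it builds two unit vectors orthogonal to each `v_i` with cross products and sweeps around them. NumPy is used so it runs without PyTorch; `make_ortho_vectors_np` is a hypothetical name, every row of `v` is assumed to be nonzero, and it has not been benchmarked.

```python
import numpy as np

def make_ortho_vectors_np(v, m):
    """For each row of v (n, 3), return m unit vectors orthogonal to it, shape (n, m, 3)."""
    v = np.asarray(v, dtype=np.float64)
    # Per row, pick the coordinate axis least aligned with v as a helper direction
    helper = np.zeros_like(v)
    helper[np.arange(len(v)), np.argmin(np.abs(v), axis=1)] = 1.0
    # e1 and e2 form an orthonormal basis of the plane perpendicular to v
    e1 = np.cross(v, helper)
    e1 /= np.linalg.norm(e1, axis=1, keepdims=True)
    e2 = np.cross(v, e1)
    e2 /= np.linalg.norm(e2, axis=1, keepdims=True)
    # m evenly spaced angles; endpoint=False avoids repeating the first direction
    thetas = np.linspace(0.0, 2.0 * np.pi, m, endpoint=False)
    return (np.cos(thetas)[None, :, None] * e1[:, None, :]
            + np.sin(thetas)[None, :, None] * e2[:, None, :])

v = np.array([[0.0, 0.0, 1.0], [1.0, 2.0, 3.0], [-4.0, 0.5, 0.0], [0.0, -2.0, 0.0]])
mane = make_ortho_vectors_np(v, 8)
```

One detail to double-check in the profiled version: `torch.linspace(0, 2*torch.pi, m)` includes both endpoints, so the first and last angle describe the same direction. The sketch uses `endpoint=False` to get `m` distinct directions.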

EricPedley commented 1 month ago

Switching the device from cuda back to cpu cuts the runtime by 2/3 🤡. So vectorizing made it faster, but surprisingly, putting it on cuda made it much slower.

 kernprof -lv tests/imaging/drone_tracker_tests.py 
20it [00:20,  1.02s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s

Total time: 15.8699 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   117                                               @profile
   118                                               def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
   119                                                   '''
   120                                                   `states` is (n, 7)
   121                                                   returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
   122                                                   '''
   123     23348      81838.9      3.5      0.5          cam = CameraModel(self.focal_len_pixels, 
   124     11674      25268.3      2.2      0.2                          [self.resolution[0]/2, self.resolution[1]/2], 
   125     11674     168326.5     14.4      1.1                          cam_pose[1].as_matrix(), 
   126     11674      43190.4      3.7      0.3                          cam_pose[0].reshape(3,1))
   127                                           
   128     11674      29753.3      2.5      0.2          n = states.shape[0]
   129     11674     241081.1     20.7      1.5          positions = states[:, :3]
   130     11674     187551.4     16.1      1.2          radii = states[:, -1]
   131                                           
   132     11674     452789.2     38.8      2.9          cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
   133     11674     154036.6     13.2      1.0          rays_to_center = positions - cam_position_tensor
   134                                                   
   135     11674       7452.6      0.6      0.0          n_samples = 25
   136     11674    6161340.0    527.8     38.8          orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
   137                                           
   138     11674     792321.9     67.9      5.0          pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
   139                                           
   140     11674     195496.6     16.7      1.2          pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
   141                                           
   142                                                   # project points into the camera
   143     11674    4111067.4    352.2     25.9          projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
   144     11674     733164.0     62.8      4.6          x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
   145     11674     696727.6     59.7      4.4          y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
   146     11674     690030.3     59.1      4.3          x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
   147     11674     646813.5     55.4      4.1          y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
   148                                           
   149     11674     451626.7     38.7      2.8          return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)
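
The four `torch.min`/`torch.max` calls plus the `vstack` on lines 144 to 149 add up to roughly 20% of this function's time. One possible variant, sketched below in NumPy so it runs without a GPU, reduces over the sample axis once for the mins and once for the maxs. The function name `boxes_from_projected_points` is hypothetical and not part of the repo. In PyTorch the matching calls would be `torch.amin` and `torch.amax` with `dim=1`; whether that is actually faster on the Orin has not been measured here.

```python
import numpy as np

def boxes_from_projected_points(projected_points: np.ndarray) -> np.ndarray:
    """Hypothetical helper (not in the repo).

    projected_points: (n, n_samples, 2) pixel coordinates.
    Returns (n, 4) boxes laid out as [x1, y1, x2, y2].
    """
    mins = projected_points.min(axis=1)  # (n, 2): x_min, y_min per particle
    maxs = projected_points.max(axis=1)  # (n, 2): x_max, y_max per particle
    return np.concatenate([mins, maxs], axis=1)  # (n, 4)

# One particle with three projected sample points.
pts = np.array([[[1.0, 5.0], [3.0, 2.0], [0.0, 4.0]]])
boxes = boxes_from_projected_points(pts)
print(boxes.shape)  # (1, 4)
```

This produces the same `[x1, y1, x2, y2]` layout as the `vstack(...).T` on line 149, with two reductions and one concatenate instead of four reductions and a transpose.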

Total time: 5.57248 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   109                                           @profile
   110                                           def make_ortho_vectors(v: torch.Tensor, m: int):
   111                                               '''
   112                                               `v` is a (n,3) tensor
   113                                               make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
   114                                               
   115                                               to visualize: imagine each v_i is the vector coinciding 
   116                                               with a lion's face direction, and we wish to make m vectors for the lion's mane.
   117                                           
   118                                               it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
   119                                               with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
   120                                               that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
   121                                           
   122                                               returns a tensor of shape (n,m,3)
   123                                               '''
   124     11674      28702.3      2.5      0.5      n = v.shape[0]
   125     11674     271867.9     23.3      4.9      thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
   126                                           
   127     11674     595917.9     51.0     10.7      phi_y = torch.atan2(v[:, 0], v[:, 2])
   128     11674     994281.7     85.2     17.8      square_sum = v[:,0]**2 + v[:,2]**2
   129     11674     623859.7     53.4     11.2      inverted = 1/torch.sqrt(square_sum)#fast_inv_sqrt(square_sum)
   130     11674     530361.5     45.4      9.5      phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
   131                                           
   132     11674     123518.0     10.6      2.2      cos_y = torch.cos(phi_y)
   133     11674     106331.2      9.1      1.9      sin_y = torch.sin(phi_y)
   134     11674      82815.4      7.1      1.5      cos_x = torch.cos(phi_x)
   135     11674      74875.7      6.4      1.3      sin_x = torch.sin(phi_x)
   136                                           
   137                                           
   138     35022     340320.8      9.7      6.1      R = torch.stack(
   139     23348     227878.7      9.8      4.1              [cos_y, -sin_y*sin_x, sin_y*cos_x,
   140     11674     111268.5      9.5      2.0              torch.zeros_like(cos_x), cos_x, sin_x,
   141     11674     216264.4     18.5      3.9              -sin_y, -cos_y*sin_x, cos_y*cos_x]
   142     11674     240959.3     20.6      4.3      ).T.reshape(n,3,3)
   143                                               # (n,3,3)
   144                                           
   145                                           
   146     23348     229906.1      9.8      4.1      vectors = torch.stack(
   147     11674       8039.4      0.7      0.1          [
   148     11674     109889.4      9.4      2.0              torch.cos(thetas), 
   149     11674      84299.1      7.2      1.5              torch.sin(thetas), 
   150     11674      91818.8      7.9      1.6              torch.zeros_like(thetas)
   151                                                   ],
   152                                               ) # (3,m)
   153                                           
   154     11674     479299.2     41.1      8.6      return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)
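
To make the geometry in the `make_ortho_vectors` docstring easier to check away from the Orin, here is a minimal NumPy sketch of the same construction. The name `make_ortho_vectors_np` is hypothetical, and the sketch makes two assumptions that differ from the profiled code. It uses `endpoint=False` in `linspace`, because `linspace(0, 2*pi, m)` returns both 0 and 2π and so yields the same vector twice. It also computes `phi_x` with `arctan2` and `hypot` instead of an explicit `1/sqrt`, which avoids dividing by zero when a ray points straight along the y axis.

```python
import numpy as np

def make_ortho_vectors_np(v: np.ndarray, m: int) -> np.ndarray:
    """Hypothetical NumPy port: v is (n, 3); returns (n, m, 3) unit vectors,
    each orthogonal to the matching row of v."""
    # Assumption: endpoint=False so all m angles are distinct.
    thetas = np.linspace(0.0, 2.0 * np.pi, m, endpoint=False)
    # "Mane" around (0, 0, 1): unit circle in the xy-plane, shape (3, m).
    mane = np.stack([np.cos(thetas), np.sin(thetas), np.zeros_like(thetas)])

    # Two angles that rotate (0, 0, 1) onto the direction of each v_i.
    phi_y = np.arctan2(v[:, 0], v[:, 2])
    phi_x = np.arctan2(v[:, 1], np.hypot(v[:, 0], v[:, 2]))
    cos_y, sin_y = np.cos(phi_y), np.sin(phi_y)
    cos_x, sin_x = np.cos(phi_x), np.sin(phi_x)
    zeros = np.zeros_like(cos_x)

    # Same row-major entries as the profiled code, stacked to (n, 3, 3).
    R = np.stack([
        cos_y, -sin_y * sin_x, sin_y * cos_x,
        zeros, cos_x, sin_x,
        -sin_y, -cos_y * sin_x, cos_y * cos_x,
    ], axis=-1).reshape(-1, 3, 3)

    # (n, 3, 3) @ (3, m) broadcasts to (n, 3, m); move the sample axis forward.
    return np.transpose(R @ mane, (0, 2, 1))

v = np.array([[0.0, 0.0, 2.0], [1.0, -3.0, 0.5]])
ring = make_ortho_vectors_np(v, 25)
print(ring.shape)  # (2, 25, 3)
```

The third column of `R` works out to `v_i / |v_i|`, so the rotated circle stays perpendicular to each ray. On the fast-inverse-square-root idea from the comment on line 130: PyTorch already has `torch.rsqrt`, so that may be worth profiling before writing anything custom, though no speed-up is claimed here.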