EricPedley closed this 1 month ago.
Benchmarking results from the Orin:
forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py
20it [01:16, 3.81s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s
Total time: 65.4881 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117
Line # Hits Time Per Hit % Time Line Contents
==============================================================
117 @profile
118 def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
119 '''
120 `states` is (n, 7)
121 returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
122 '''
123 23260 100939.8 4.3 0.2 cam = CameraModel(self.focal_len_pixels,
124 11630 27494.0 2.4 0.0 [self.resolution[0]/2, self.resolution[1]/2],
125 11630 228127.5 19.6 0.3 cam_pose[1].as_matrix(),
126 11630 50272.7 4.3 0.1 cam_pose[0].reshape(3,1))
127
128 11630 44174.0 3.8 0.1 n = states.shape[0]
129 11630 381818.1 32.8 0.6 positions = states[:, :3]
130 11630 317663.1 27.3 0.5 radii = states[:, -1]
131
132 11630 1312691.4 112.9 2.0 cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
133 11630 594472.2 51.1 0.9 rays_to_center = positions - cam_position_tensor
134
135 11630 12479.0 1.1 0.0 n_samples = 25
136 11630 37973416.9 3265.1 58.0 orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
137
138 11630 1852451.1 159.3 2.8 pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
139
140 11630 296315.5 25.5 0.5 pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
141
142 # project points into the camera
143 11630 16447452.0 1414.2 25.1 projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
144 11630 1283011.3 110.3 2.0 x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
145 11630 1186818.7 102.0 1.8 y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
146 11630 1225458.6 105.4 1.9 x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
147 11630 1167045.2 100.3 1.8 y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
148
149 11630 985989.0 84.8 1.5 return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)
Total time: 14.5859 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: update at line 151
Line # Hits Time Per Hit % Time Line Contents
==============================================================
151 @profile
152 def update(self, cam_pose: tuple[np.ndarray, Rotation], measurement:Measurement):
153 '''
154 measurements is a list of 2D integer bounding boxes in pixel coordinates (x1,y1,x2,y2)
155 '''
156
157 # add particles to `samples` that would line up with the measurements
158 # self.samples.extend(
159 # self.gen_samples_from_measurement(cam_pose, measurement.box, 10)
160 # )
161
162 19 14454575.8 760767.1 99.1 measurements = self.compute_measurements(cam_pose, self.samples)
163 19 87296.3 4594.5 0.6 likelihoods = self.compute_likelihoods(measurements, measurement.box.to_xyxy())
164
165 # resample the particles
166 19 44032.7 2317.5 0.3 self.resample(likelihoods)
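In the run above, the four separate torch.min / torch.max calls that build the boxes (x_mins through y_maxs) together take about 7.5% of compute_measurements, and each one is its own pass over a strided slice. Below is a minimal NumPy sketch of the same box computation using two reductions over the sample axis instead of four. It assumes projected_points is laid out as (n, n_samples, 2), as in the profiled code. The function name is hypothetical. In PyTorch, the analogous single call would be torch.aminmax(projected_points, dim=1), which returns the minima and maxima together. Whether either version helps on the Orin would need measuring.

```python
import numpy as np


def boxes_from_projected_points(projected_points: np.ndarray) -> np.ndarray:
    """Axis-aligned boxes [x1, y1, x2, y2] from projected sphere samples.

    projected_points: (n, n_samples, 2) pixel coordinates (u, v) for n particles.
    Returns an (n, 4) array in the same column order as the profiled code.
    """
    mins = projected_points.min(axis=1)  # (n, 2): [x_min, y_min] per particle
    maxs = projected_points.max(axis=1)  # (n, 2): [x_max, y_max] per particle
    return np.concatenate([mins, maxs], axis=1)  # (n, 4)


# One particle with three projected samples.
pts = np.array([[[1.0, 5.0], [3.0, 2.0], [2.0, 4.0]]])
print(boxes_from_projected_points(pts))  # [[1. 2. 3. 5.]]
```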
Baseline benchmark without these changes (but still with the plotting code commented out):
forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py
20it [00:44, 2.24s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s
Total time: 26.6667 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurement at line 121
Line # Hits Time Per Hit % Time Line Contents
==============================================================
121 @profile
122 def compute_measurement(self, cam_pose: tuple[np.ndarray, Rotation], state: np.ndarray) -> BoundingBox:
123 '''
124 `state` is the state of the track, which is a 7 element array
125 '''
126 # if behind the camera, return a box with 0 area
127 73248 3959894.6 54.1 14.8 if np.dot(state[:3] - cam_pose[0], cam_pose[1].apply([0,0,1])) < 0:
128 1 4.4 4.4 0.0 return BoundingBox(0, 0, 0, 0)
129
130 146494 383314.2 2.6 1.4 cam = CameraModel(self.focal_len_pixels,
131 73247 101525.0 1.4 0.4 [self.resolution[0]/2, self.resolution[1]/2],
132 73247 472996.9 6.5 1.8 cam_pose[1].as_matrix(),
133 73247 180592.5 2.5 0.7 cam_pose[0].reshape(3,1))
134
135 73247 90192.4 1.2 0.3 state_position = state[:3]
136 73247 67216.9 0.9 0.3 state_radius = state[-1]
137
138 73247 262011.1 3.6 1.0 ray_to_center = state_position - cam_pose[0]
139
140 # monte carlo to find the circumscribed rectangle around the sphere's projection into the camera
141 # there's probably a better way to do this but I'm not sure what it is
142 # I tried a method where we project 4 points on the boundary and fit a 2d ellipse to their projection
143 # but the ellipse fitting was not working well
144 73247 29165.9 0.4 0.1 n_samples = 100
145
146 # sample points on the sphere
147 73247 1464999.2 20.0 5.5 random_vector = np.random.randn(3, n_samples)
148 73247 4822581.3 65.8 18.1 random_vector -= np.dot(random_vector.T, ray_to_center) * np.repeat([ray_to_center / np.linalg.norm(ray_to_center)], n_samples, axis=0).T
149 73247 2824607.9 38.6 10.6 random_vector = random_vector / np.linalg.norm(random_vector, axis=0) * state_radius
150
151 # project points into the camera
152 73247 7656917.1 104.5 28.7 projected_points = cam.project(state_position.reshape((3,1)) + random_vector)
153 73247 1395048.9 19.0 5.2 x_min = np.min(projected_points[0])
154 73247 951478.9 13.0 3.6 x_max = np.max(projected_points[0])
155 73247 807701.3 11.0 3.0 y_min = np.min(projected_points[1])
156 73247 776503.6 10.6 2.9 y_max = np.max(projected_points[1])
157
158 73247 419950.5 5.7 1.6 return BoundingBox((x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min)
Total time: 22.3501 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: update at line 160
Line # Hits Time Per Hit % Time Line Contents
==============================================================
160 @profile
161 def update(self, cam_pose: tuple[np.ndarray, Rotation], measurement:Measurement):
162 '''
163 measurements is a list of 2D integer bounding boxes in pixel coordinates (x1,y1,x2,y2)
164 '''
165
166 # add particles to `samples` that would line up with the measurements
167 # self.samples.extend(
168 # self.gen_samples_from_measurement(cam_pose, measurement.box, 10)
169 # )
170
171 19019 27968.6 1.5 0.1 for i, particle in enumerate(self.samples):
172
173 # compute the likelihood of the particle given the measurement
174 # by comparing the measurement to the particle's predicted
175 # measurement
176 19000 8726033.1 459.3 39.0 predicted_measurement = self.compute_measurement(cam_pose, particle.state).to_xyxy()
177 19000 12061590.6 634.8 54.0 particle.likelihood = self.compute_likelihood(predicted_measurement, measurement.box.to_xyxy())
178
179 # resample the particles
180 19 1534535.3 80765.0 6.9 self.resample()
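The baseline's compute_measurement notes that there's "probably a better way" than Monte Carlo sampling to get the rectangle around a sphere's projection. For an ideal pinhole camera there is a closed form. The extreme values of x/z over the sphere lie on the two planes through the camera center that contain the camera's y axis and are tangent to the sphere. For a center (X, Y, Z) and radius r, those slopes t = x/z are the roots of t²(Z² − r²) − 2XZ·t + (X² − r²) = 0, and the same holds for y/z with Y in place of X.

Below is a minimal NumPy sketch. It rests on several assumptions:
- centers are already in the camera frame with z pointing forward (how CameraModel transforms points isn't shown here)
- there is one focal length f in pixels and a principal point (px, py)
- there is no lens distortion
- every sphere is entirely in front of the camera (z > r)

The function name is hypothetical. Unlike the baseline, which returns a zero-area box for spheres behind the camera, this sketch raises, so a caller would need to handle those particles separately.

```python
import numpy as np


def sphere_boxes(centers_cam: np.ndarray, radii: np.ndarray, f: float,
                 px: float, py: float) -> np.ndarray:
    """Exact pixel-space boxes [x1, y1, x2, y2] of projected spheres.

    centers_cam: (n, 3) sphere centers in the camera frame (z forward).
    radii: (n,) sphere radii.
    Assumes a distortion-free pinhole camera and that every sphere is fully
    in front of the camera (z > r).
    """
    X, Y, Z = centers_cam[:, 0], centers_cam[:, 1], centers_cam[:, 2]
    r = radii
    if np.any(Z <= r):
        raise ValueError("closed form assumes every sphere satisfies z > r")
    denom = Z**2 - r**2  # strictly positive under the z > r assumption

    def tangent_slopes(a: np.ndarray):
        # Roots of t^2 (Z^2 - r^2) - 2 a Z t + (a^2 - r^2) = 0: the slopes a/z
        # of the two tangent planes through the camera center.
        half_width = r * np.sqrt(a**2 + Z**2 - r**2)
        return (a * Z - half_width) / denom, (a * Z + half_width) / denom

    tx_min, tx_max = tangent_slopes(X)
    ty_min, ty_max = tangent_slopes(Y)
    return np.stack([f * tx_min + px, f * ty_min + py,
                     f * tx_max + px, f * ty_max + py], axis=1)


# Sphere of radius 3 centered 5 units straight ahead: slopes are +/- 3/4.
print(sphere_boxes(np.array([[0.0, 0.0, 5.0]]), np.array([3.0]), 100.0, 320.0, 240.0))
# [[245. 165. 395. 315.]]
```

This replaces both the sampled ring of points and their projection with a handful of elementwise operations per particle. The catch is that it hard-codes the pinhole model, so it only stays correct as long as CameraModel does too.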
The slowest parts seem to be computing the vectors orthogonal to the camera-to-particle rays (to get the sphere boundaries) and projecting those points into the camera. Also, fast inverse square root isn't actually faster than 1/torch.sqrt(x) (3.28 s/it with it vs. 3.09 s/it without, in the runs below) :(
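For reference, the fast_inv_sqrt profiled below is the classic Quake III bit trick:
1. Reinterpret the float32 bits as an integer.
2. Subtract half of that integer from the magic constant 0x5f3759df.
3. Reinterpret the result as a float.
4. Refine with one Newton step.

Here is a NumPy port of the same steps. It's a sketch for checking accuracy, not a replacement for the PyTorch code, and the function name is hypothetical. It shows two things. First, after one Newton step the result is only accurate to roughly 0.2%. Second, it takes about seven elementwise array passes, where 1/sqrt(x) takes two (and torch.rsqrt does it in one). The trick pays off in scalar C code. On a vectorized backend, though, every extra pass is a full trip over the array, so it has nothing to win.

```python
import numpy as np

MAGIC = np.int32(0x5F3759DF)  # constant from Quake III Arena's Q_rsqrt


def fast_inv_sqrt_np(x: np.ndarray) -> np.ndarray:
    """Approximate 1/sqrt(x) for positive float32 x (bit hack + one Newton step)."""
    x = np.asarray(x, dtype=np.float32)
    i = x.view(np.int32)      # reinterpret the float bits as int32 (no copy)
    i = MAGIC - (i >> 1)      # initial guess from the magic constant
    y = i.view(np.float32)    # reinterpret back to float32
    # One Newton-Raphson step for f(y) = 1/y^2 - x.
    return y * (np.float32(1.5) - np.float32(0.5) * x * y * y)


x = np.array([0.25, 1.0, 100.0], dtype=np.float32)
print(fast_inv_sqrt_np(x))  # close to [2.0, 1.0, 0.1], but not exact
print(1.0 / np.sqrt(x))     # exact to float32 precision
```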
Without fast inv sqrt:
forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py
20it [01:01, 3.09s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s
Total time: 50.4935 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117
Line # Hits Time Per Hit % Time Line Contents
==============================================================
117 @profile
118 def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
119 '''
120 `states` is (n, 7)
121 returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
122 '''
123 23452 106006.8 4.5 0.2 cam = CameraModel(self.focal_len_pixels,
124 11726 27706.6 2.4 0.1 [self.resolution[0]/2, self.resolution[1]/2],
125 11726 228439.9 19.5 0.5 cam_pose[1].as_matrix(),
126 11726 51397.0 4.4 0.1 cam_pose[0].reshape(3,1))
127
128 11726 43833.8 3.7 0.1 n = states.shape[0]
129 11726 417475.7 35.6 0.8 positions = states[:, :3]
130 11726 347961.7 29.7 0.7 radii = states[:, -1]
131
132 11726 1407496.1 120.0 2.8 cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
133 11726 631097.3 53.8 1.2 rays_to_center = positions - cam_position_tensor
134
135 11726 12367.5 1.1 0.0 n_samples = 25
136 11726 21668508.7 1847.9 42.9 orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
137
138 11726 1967327.5 167.8 3.9 pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
139
140 11726 284906.5 24.3 0.6 pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
141
142 # project points into the camera
143 11726 17022874.9 1451.7 33.7 projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
144 11726 1355292.4 115.6 2.7 x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
145 11726 1301534.9 111.0 2.6 y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
146 11726 1335170.3 113.9 2.6 x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
147 11726 1268481.5 108.2 2.5 y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
148
149 11726 1015606.3 86.6 2.0 return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)
Total time: 20.6904 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109
Line # Hits Time Per Hit % Time Line Contents
==============================================================
109 @profile
110 def make_ortho_vectors(v: torch.Tensor, m: int):
111 '''
112 `v` is a (n,3) tensor
113 make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
114
115 to visualize: imagine each v_i is the vector coinciding
116 with a lion's face direction, and we wish to make m vectors for the lion's mane.
117
118 it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
119 with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
120 that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
121
122 returns a tensor of shape (n,m,3)
123 '''
124 11726 46500.6 4.0 0.2 n = v.shape[0]
125 11726 1202476.2 102.5 5.8 thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
126
127 11726 1352173.7 115.3 6.5 phi_y = torch.atan2(v[:, 0], v[:, 2])
128 11726 2734895.5 233.2 13.2 square_sum = v[:,0]**2 + v[:,2]**2
129 11726 1800831.5 153.6 8.7 inverted = 1/torch.sqrt(square_sum)#fast_inv_sqrt(square_sum)
130 11726 1439063.6 122.7 7.0 phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
131
132 11726 487013.4 41.5 2.4 cos_y = torch.cos(phi_y)
133 11726 485644.9 41.4 2.3 sin_y = torch.sin(phi_y)
134 11726 453987.2 38.7 2.2 cos_x = torch.cos(phi_x)
135 11726 447829.4 38.2 2.2 sin_x = torch.sin(phi_x)
136
137
138 35178 1034150.2 29.4 5.0 R = torch.stack(
139 23452 1423668.2 60.7 6.9 [cos_y, -sin_y*sin_x, sin_y*cos_x,
140 11726 575882.8 49.1 2.8 torch.zeros_like(cos_x), cos_x, sin_x,
141 11726 1732036.2 147.7 8.4 -sin_y, -cos_y*sin_x, cos_y*cos_x]
142 11726 337977.1 28.8 1.6 ).T.reshape(n,3,3)
143 # (n,3,3)
144
145
146 23452 789558.7 33.7 3.8 vectors = torch.stack(
147 11726 15118.9 1.3 0.1 [
148 11726 504119.3 43.0 2.4 torch.cos(thetas),
149 11726 466240.2 39.8 2.3 torch.sin(thetas),
150 11726 546511.6 46.6 2.6 torch.zeros_like(thetas)
151 ],
152 ) # (3,m)
153
154 11726 2814702.0 240.0 13.6 return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)
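In the profile above, the angle computations in make_ortho_vectors (atan2, the squared sum, the sqrt, and atan) take about a third of its time. Assembling R takes roughly another quarter, and the (cos θ, sin θ, 0) ring is rebuilt on every call.

One alternative (a different technique from what this PR does) skips the angles entirely:
1. Normalize each ray.
2. Cross it with a helper axis that isn't parallel to it, giving a unit vector u orthogonal to the ray.
3. Take w = v̂ × u as the second orthogonal unit vector.
4. Form the m ring vectors as cos(θ_k)·u + sin(θ_k)·w.

The cos/sin table depends only on m, so it could be computed once. The result is the same kind of evenly spaced ring, though not with the same starting angle as the rotation-matrix version.

Below is a NumPy sketch under those assumptions; the function name is hypothetical and the speed on the Orin is untested. It also passes endpoint=False to linspace. With linspace(0, 2π, m) as written, both 0 and 2π are included, so the first and last of the m vectors coincide.

```python
import numpy as np


def make_ortho_vectors_np(v: np.ndarray, m: int) -> np.ndarray:
    """(n, 3) rays -> (n, m, 3) unit vectors orthogonal to each ray, evenly spaced.

    Angle-free sketch: builds an orthonormal pair (u, w) perpendicular to each
    ray with cross products instead of rotation angles.
    """
    v_hat = v / np.linalg.norm(v, axis=1, keepdims=True)  # (n, 3)
    # Helper axis: world y, unless the ray is nearly parallel to it, then world x.
    helper = np.where(np.abs(v_hat[:, 1:2]) < 0.9,
                      np.array([[0.0, 1.0, 0.0]]),
                      np.array([[1.0, 0.0, 0.0]]))
    u = np.cross(v_hat, helper)
    u /= np.linalg.norm(u, axis=1, keepdims=True)  # (n, 3), unit and perpendicular to v_hat
    w = np.cross(v_hat, u)  # unit length because v_hat and u are orthonormal
    thetas = np.linspace(0.0, 2.0 * np.pi, m, endpoint=False)  # m distinct angles
    cos_t, sin_t = np.cos(thetas), np.sin(thetas)  # (m,), depends only on m
    return cos_t[None, :, None] * u[:, None, :] + sin_t[None, :, None] * w[:, None, :]


rays = np.array([[0.0, 0.0, 5.0], [1.0, 2.0, 3.0]])
print(make_ortho_vectors_np(rays, 25).shape)  # (2, 25, 3)
```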
With fast inv sqrt:
forge@uav-forge-orin:~/uavf_2024
> kernprof -lv tests/imaging/drone_tracker_tests.py
20it [01:05, 3.28s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s
Total time: 54.3264 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117
Line # Hits Time Per Hit % Time Line Contents
==============================================================
117 @profile
118 def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
119 '''
120 `states` is (n, 7)
121 returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
122 '''
123 23416 102576.1 4.4 0.2 cam = CameraModel(self.focal_len_pixels,
124 11708 26484.2 2.3 0.0 [self.resolution[0]/2, self.resolution[1]/2],
125 11708 228881.3 19.5 0.4 cam_pose[1].as_matrix(),
126 11708 52137.3 4.5 0.1 cam_pose[0].reshape(3,1))
127
128 11708 44564.4 3.8 0.1 n = states.shape[0]
129 11708 418965.1 35.8 0.8 positions = states[:, :3]
130 11708 354484.2 30.3 0.7 radii = states[:, -1]
131
132 11708 1411121.9 120.5 2.6 cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
133 11708 634851.9 54.2 1.2 rays_to_center = positions - cam_position_tensor
134
135 11708 11406.4 1.0 0.0 n_samples = 25
136 11708 25274502.8 2158.7 46.5 orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
137
138 11708 1994273.8 170.3 3.7 pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
139
140 11708 287269.6 24.5 0.5 pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
141
142 # project points into the camera
143 11708 17174573.5 1466.9 31.6 projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
144 11708 1356988.1 115.9 2.5 x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
145 11708 1302896.4 111.3 2.4 y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
146 11708 1346783.8 115.0 2.5 x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
147 11708 1274790.4 108.9 2.3 y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
148
149 11708 1028859.9 87.9 1.9 return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)
Total time: 4.82129 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: fast_inv_sqrt at line 92
Line # Hits Time Per Hit % Time Line Contents
==============================================================
92 @profile
93 def fast_inv_sqrt(x: torch.Tensor):
94 '''
95 Calculates 1/sqrt(x) really fast.
96 If x is (n,) this will be vectorized too
97
98 '''
99 11708 11281.0 1.0 0.2 three_halfs = 1.5
100 11708 668128.5 57.1 13.9 x2 = x * 0.5
101 11708 11269.4 1.0 0.2 y = x
102 11708 161805.5 13.8 3.4 i = x.view(torch.int32)
103 11708 1532964.6 130.9 31.8 i = 0x5f3759df - (i>>1)
104 11708 132011.4 11.3 2.7 y = i.view(torch.float32)
105 11708 2292938.0 195.8 47.6 y = y * (three_halfs - (x2 * y * y))
106
107 11708 10894.5 0.9 0.2 return y
Total time: 24.3137 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109
Line # Hits Time Per Hit % Time Line Contents
==============================================================
109 @profile
110 def make_ortho_vectors(v: torch.Tensor, m: int):
111 '''
112 `v` is a (n,3) tensor
113 make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
114
115 to visualize: imagine each v_i is the vector coinciding
116 with a lion's face direction, and we wish to make m vectors for the lion's mane.
117
118 it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
119 with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
120 that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
121
122 returns a tensor of shape (n,m,3)
123 '''
124 11708 47894.3 4.1 0.2 n = v.shape[0]
125 11708 1210825.1 103.4 5.0 thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
126
127 11708 1368071.0 116.8 5.6 phi_y = torch.atan2(v[:, 0], v[:, 2])
128 11708 2776173.9 237.1 11.4 square_sum = v[:,0]**2 + v[:,2]**2
129 11708 5167081.1 441.3 21.3 inverted = fast_inv_sqrt(square_sum)
130 11708 1467019.2 125.3 6.0 phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
131
132 11708 490739.4 41.9 2.0 cos_y = torch.cos(phi_y)
133 11708 476400.1 40.7 2.0 sin_y = torch.sin(phi_y)
134 11708 461560.0 39.4 1.9 cos_x = torch.cos(phi_x)
135 11708 450491.7 38.5 1.9 sin_x = torch.sin(phi_x)
136
137
138 35124 1041920.8 29.7 4.3 R = torch.stack(
139 23416 1458563.3 62.3 6.0 [cos_y, -sin_y*sin_x, sin_y*cos_x,
140 11708 591757.9 50.5 2.4 torch.zeros_like(cos_x), cos_x, sin_x,
141 11708 1800907.1 153.8 7.4 -sin_y, -cos_y*sin_x, cos_y*cos_x]
142 11708 339204.8 29.0 1.4 ).T.reshape(n,3,3)
143 # (n,3,3)
144
145
146 23416 800002.8 34.2 3.3 vectors = torch.stack(
147 11708 13927.3 1.2 0.1 [
148 11708 506071.3 43.2 2.1 torch.cos(thetas),
149 11708 470040.7 40.1 1.9 torch.sin(thetas),
150 11708 546612.1 46.7 2.2 torch.zeros_like(thetas)
151 ],
152 ) # (3,m)
153
154 11708 2828458.1 241.6 11.6 return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)
Switching the device from CUDA back to CPU cuts the runtime by about two thirds (3.09 s/it down to 1.02 s/it) 🤡. So vectorizing made it faster, but, surprisingly, putting it on CUDA made it much slower.
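One thing to keep in mind when reading the CUDA profiles: PyTorch queues CUDA kernels asynchronously, while line_profiler records host wall-clock time per line. That time can therefore reflect launch overhead plus whichever line happens to wait on the GPU, rather than the actual cost of each operation. With small per-call tensors, launch overhead may dominate. Per-call host-to-device copies such as Tensor(cam_pose[0]).to(self._device) are also worth checking.

Below is a small stdlib-only helper, with a hypothetical name, for timing a callable fairly on either device. It takes an optional synchronization hook that runs before and after each timed call, so queued GPU work is counted. The assumed usage is passing torch.cuda.synchronize as that hook. PyTorch's torch.utils.benchmark.Timer also handles CUDA synchronization, if you'd rather not roll your own.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(fn: Callable[[], object], sync: Optional[Callable[[], None]] = None,
              warmup: int = 3, repeats: int = 20) -> float:
    """Median wall-clock seconds per call of fn().

    If `sync` is given (e.g. torch.cuda.synchronize), it runs right before and
    right after each timed call, so asynchronous device work is included.
    """
    for _ in range(warmup):  # let caches, allocators, and lazy init settle
        fn()
    samples = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # make sure earlier queued work isn't billed to this call
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for this call's queued work to finish
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

For example, compare time_call(lambda: make_ortho_vectors(rays_gpu, 25), sync=torch.cuda.synchronize) with the same call on a CPU tensor and no sync, where rays_gpu is a hypothetical (n, 3) CUDA tensor.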
kernprof -lv tests/imaging/drone_tracker_tests.py
20it [00:20, 1.02s/it]
Wrote profile results to drone_tracker_tests.py.lprof
Timer unit: 1e-06 s
Total time: 15.8699 s
File: /home/forge/uavf_2024/uavf_2024/imaging/particle_filter.py
Function: compute_measurements at line 117
Line # Hits Time Per Hit % Time Line Contents
==============================================================
117 @profile
118 def compute_measurements(self, cam_pose: tuple[np.ndarray, Rotation], states: np.ndarray) -> BoundingBox:
119 '''
120 `states` is (n, 7)
121 returns ndarray of shape (n, 4) where the 4 elements are [x1,y1,x2,y2]
122 '''
123 23348 81838.9 3.5 0.5 cam = CameraModel(self.focal_len_pixels,
124 11674 25268.3 2.2 0.2 [self.resolution[0]/2, self.resolution[1]/2],
125 11674 168326.5 14.4 1.1 cam_pose[1].as_matrix(),
126 11674 43190.4 3.7 0.3 cam_pose[0].reshape(3,1))
127
128 11674 29753.3 2.5 0.2 n = states.shape[0]
129 11674 241081.1 20.7 1.5 positions = states[:, :3]
130 11674 187551.4 16.1 1.2 radii = states[:, -1]
131
132 11674 452789.2 38.8 2.9 cam_position_tensor = Tensor(cam_pose[0]).to(self._device)
133 11674 154036.6 13.2 1.0 rays_to_center = positions - cam_position_tensor
134
135 11674 7452.6 0.6 0.0 n_samples = 25
136 11674 6161340.0 527.8 38.8 orthogonal_rays_normalized = make_ortho_vectors(rays_to_center, n_samples)
137
138 11674 792321.9 67.9 5.0 pts3 = positions[:, None, :] + orthogonal_rays_normalized * radii[:,None,None]
139
140 11674 195496.6 16.7 1.2 pts3_flat = pts3.reshape(-1, 3) # (n*n_samples, 3)
141
142 # project points into the camera
143 11674 4111067.4 352.2 25.9 projected_points = cam.project(pts3_flat.T, self._device).T.reshape(n, n_samples, 2)
144 11674 733164.0 62.8 4.6 x_mins = torch.min(projected_points[:,:,0], dim=1).values # (n,)
145 11674 696727.6 59.7 4.4 y_mins = torch.min(projected_points[:,:,1], dim=1).values # (n,)
146 11674 690030.3 59.1 4.3 x_maxs = torch.max(projected_points[:,:,0], dim=1).values # (n,)
147 11674 646813.5 55.4 4.1 y_maxs = torch.max(projected_points[:,:,1], dim=1).values # (n,)
148
149 11674 451626.7 38.7 2.8 return torch.vstack([x_mins, y_mins, x_maxs, y_maxs]).T # (n, 4)
Total time: 5.57248 s
File: /home/forge/uavf_2024/uavf_2024/imaging/utils.py
Function: make_ortho_vectors at line 109
Line # Hits Time Per Hit % Time Line Contents
==============================================================
109 @profile
110 def make_ortho_vectors(v: torch.Tensor, m: int):
111 '''
112 `v` is a (n,3) tensor
113 make m unit vectors that are orthogonal to each v_i, and evenly spaced around v_i's radial symmetry
114
115 to visualize: imagine each v_i is the vector coinciding
116 with a lion's face direction, and we wish to make m vectors for the lion's mane.
117
118 it does this by making a "lion's mane" around the vector (0,0,1), which is easy with parameterizing
119 with theta and using (cos(theta), sin(theta), 0). Then, it figures out the 2DOF R_x @ R_y rotation matrix
120 that would rotate (0,0,1) into v_i, and applies it to those mane vectors.
121
122 returns a tensor of shape (n,m,3)
123 '''
124 11674 28702.3 2.5 0.5 n = v.shape[0]
125 11674 271867.9 23.3 4.9 thetas = torch.linspace(0, 2*torch.pi, m).to(v.device)
126
127 11674 595917.9 51.0 10.7 phi_y = torch.atan2(v[:, 0], v[:, 2])
128 11674 994281.7 85.2 17.8 square_sum = v[:,0]**2 + v[:,2]**2
129 11674 623859.7 53.4 11.2 inverted = 1/torch.sqrt(square_sum)#fast_inv_sqrt(square_sum)
130 11674 530361.5 45.4 9.5 phi_x = torch.atan(v[:, 1] * inverted) # This line is responsible for like 20-25% of the runtime of this function, so unironically if we implement fast inverse square root in pytorch we can get huge performance gains
131
132 11674 123518.0 10.6 2.2 cos_y = torch.cos(phi_y)
133 11674 106331.2 9.1 1.9 sin_y = torch.sin(phi_y)
134 11674 82815.4 7.1 1.5 cos_x = torch.cos(phi_x)
135 11674 74875.7 6.4 1.3 sin_x = torch.sin(phi_x)
136
137
138 35022 340320.8 9.7 6.1 R = torch.stack(
139 23348 227878.7 9.8 4.1 [cos_y, -sin_y*sin_x, sin_y*cos_x,
140 11674 111268.5 9.5 2.0 torch.zeros_like(cos_x), cos_x, sin_x,
141 11674 216264.4 18.5 3.9 -sin_y, -cos_y*sin_x, cos_y*cos_x]
142 11674 240959.3 20.6 4.3 ).T.reshape(n,3,3)
143 # (n,3,3)
144
145
146 23348 229906.1 9.8 4.1 vectors = torch.stack(
147 11674 8039.4 0.7 0.1 [
148 11674 109889.4 9.4 2.0 torch.cos(thetas),
149 11674 84299.1 7.2 1.5 torch.sin(thetas),
150 11674 91818.8 7.9 1.6 torch.zeros_like(thetas)
151 ],
152 ) # (3,m)
153
154 11674 479299.2 41.1 8.6 return torch.matmul(R, vectors).permute(0, 2, 1) # (n, m, 3)
Summary
Speeds up the particle filter by around 4x. I think we could still try more performance optimizations, but it'd be hard. The slowest part right now is the measurement function, which has to find vectors on the edges of the bounding spheres and run the camera projection on them. The camera projection is about a quarter of that function's runtime (25.9% in the CPU profile above). I think we could speed it up by using quaternions for the camera rotation instead of rotation matrices. Also, the code is somehow slower on the GPU than on the CPU, even though the slowest parts are PyTorch operations. That suggests something is wrong, so it's worth debugging how the data is moved to the GPU and operated on.
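On the quaternion idea, a quick way to check it is to rotate a batch of points both ways, confirm they agree, and time them. The sketch below uses the standard unit-quaternion identity v' = v + 2w(q×v) + 2q×(q×v). It takes q = (x, y, z, w), the scalar-last order that SciPy's Rotation.as_quat() returns by default. The function names are hypothetical.

Per point, this identity does roughly twice the arithmetic of a 3×3 matrix-vector product: two cross products plus scaling, versus 9 multiplies and 6 adds. So it's worth benchmarking before switching.

```python
import numpy as np


def rotate_by_quaternion(q: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Rotate (n, 3) points by unit quaternion q = (x, y, z, w), scalar last."""
    u, w = q[:3], q[3]
    uv = np.cross(u, points)  # (n, 3); the (3,) vector broadcasts over rows
    return points + 2.0 * w * uv + 2.0 * np.cross(u, uv)


def quaternion_to_matrix(q: np.ndarray) -> np.ndarray:
    """3x3 rotation matrix for unit quaternion q = (x, y, z, w)."""
    x, y, z, w = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])


# 90 degrees about z maps +x to +y with either representation.
s = np.sqrt(0.5)
q = np.array([0.0, 0.0, s, s])
p = np.array([[1.0, 0.0, 0.0]])
print(rotate_by_quaternion(q, p))        # approximately [[0. 1. 0.]]
print(p @ quaternion_to_matrix(q).T)     # approximately [[0. 1. 0.]]
```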