diff options
Diffstat (limited to 'day19.py')
-rw-r--r-- | day19.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/day19.py b/day19.py new file mode 100644 index 0000000..7e70d68 --- /dev/null +++ b/day19.py @@ -0,0 +1,137 @@ +from collections import Counter, defaultdict +import numpy as np + +def fit_homography(P1, P2): + # p is a size (N, 3) set of 3d points + # X is a size (N, 4) set of 3d points + + N, _ = P1.shape + M = np.zeros((N*3, 12)) + b = np.zeros((N*3, 1)) + for i, (p, y) in enumerate(zip(P1, P2)): + b[i*3:i*3+3, 0] = y + + M[i*3, :3] = p + M[i*3 + 1, 3:6] = p + M[i*3 + 2, 6:9] = p + + M[i*3: i*3 + 3, 9:] = np.eye(3) + + x = np.linalg.inv(M.T @ M) @ M.T @ b + + R = x[:9].reshape(3,3).round().astype(int) + t = x[9:].round().astype(int) + return R, t + +class SensorData: + def __init__(self, points: np.ndarray): + self.points = points + self.point_to_dists = defaultdict(set) + + self.sensors = np.zeros((1,3)) + + # Get dists for points + self._find_dists() + + def _find_dists(self): + self.point_to_dists = defaultdict(set) + for p1 in self.points: + for p2 in self.points: + if tuple(p1) == tuple(p2): continue + d = ((p1 - p2) ** 2).sum().item() + self.point_to_dists[tuple(p1)].add(d) + self.point_to_dists[tuple(p2)].add(d) + + def add_points(self, other): + matches = self.match_points(other) + + if matches is None or matches.shape[0] < 12: + return False + + P1, P2 = self.points[matches[:,0], :], other.points[matches[:,1], :] + + R, t = fit_homography(P2, P1) + + transformed = other.points @ R.T + t.T + + self.points = np.concatenate([self.points, transformed]) + self.points = np.unique(self.points, axis=0) + # self._find_dists() + for new_point, old_point in zip(transformed, other.points): + self.point_to_dists[tuple(new_point)].update(other.point_to_dists[tuple(old_point)]) + + # Also map the sensor location to store list of sensors + self.sensors = np.concatenate([ + self.sensors, + other.sensors @ R.T + t.T + ]) + + return True + + def match_points(self, other): + paired = [] + for j, p2 in enumerate(other.points): + for i, p1 in enumerate(self.points): + dists1 = self.point_to_dists[tuple(p1)] + dists2 = other.point_to_dists[tuple(p2)] + common = dists1.intersection(dists2) + + if len(common) < 11: continue + + paired.append((i, j)) + + if len(paired) == 12: + break + + if not paired: return None + + return np.stack(paired) + + def max_dist(self): + best = 0 + for s1 in self.sensors: + for s2 in self.sensors: + d = np.abs(s1 - s2).sum().item() + + best = max(d, best) + return int(best) + +def parse_input(contents): + sensors = contents.split('\n\n') + + output = [] + for sensor in sensors: + output.append( + np.stack( + [list(map(int, line.split(','))) + for line in sensor.splitlines()[1:]] + ) + ) + + return output + +def part_1(sensors): + sensors = [SensorData(np.copy(n)) for n in sensors] + + data = sensors[0] + to_pair = set(range(1, len(sensors))) + while to_pair: + again = set() + for i in to_pair: + if data.add_points(sensors[i]): continue + again.add(i) + + to_pair = again + + return data + +if __name__ == "__main__": + print('--- Part 1 ---') + with open('day19.txt') as f: + sensors = parse_input(f.read()) + + data = part_1(sensors) + print(data.points.shape[0]) + + print('--- Part 2 ---') + print(data.max_dist()) |