summary refs log tree commit diff stats
path: root/day16.py
diff options
context:
space:
mode:
Diffstat (limited to 'day16.py')
-rw-r--r--day16.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/day16.py b/day16.py
new file mode 100644
index 0000000..5adedfc
--- /dev/null
+++ b/day16.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+import math
+
+hex2bin = {digit : f'{i:>04b}' for i, digit in enumerate("0123456789ABCDEF")}
+with open('day16.txt') as data:
+    bits = "".join(hex2bin[char] for char in data.read().strip())
+
+pos = 0
+def read_bits(num_bits):
+    global pos
+    num = bits[pos:pos+num_bits]
+    pos += num_bits
+    return num
+
+def read_int(num_bits):
+# reads arbitrary number of bits from bitstring to a number
+    bit_str = read_bits(num_bits)
+    return int(bit_str, 2)
+
+def read_literal_int():
+# reads the value of a literal: type id = 4
+    total = 0
+    while True:
+        digit = read_bits(5)
+        total <<= 4
+        total += int(digit[1:], 2)
+        if digit[0] == '0':
+            break
+    return total
+
+def parse_packet():
+    version = read_int(3)
+    type_id = read_int(3)
+    if type_id == 4:
+        literal = read_literal_int()
+        return {
+            'version': version,
+            'type_id': type_id,
+            'literal': literal
+        }
+    else:
+        length_id = read_int(1)
+        if length_id == 0:
+            bit_length = read_int(15)
+            stop_pos = pos + bit_length
+            subpackets = []
+            while pos < stop_pos:
+                p = parse_packet()
+                subpackets.append(p)
+            return {
+                'version': version,
+                'type_id': type_id,
+                'packets': subpackets
+            }
+        else:
+            sub_length = read_int(11)
+            subpackets = []
+            for i in range(sub_length):
+                subpackets.append(parse_packet())
+            return {
+                'version': version,
+                'type_id': type_id,
+                'packets': subpackets
+            }
+
+def get_version_sum(packet):
+    version_sum = packet['version']
+    if 'literal' not in packet:
+        for subpacket in packet['packets']:
+            version_sum += get_version_sum(subpacket)
+    return version_sum
+
+def evaluate_packet(packet):
+    match packet['type_id']:
+        case 0:
+            return sum(evaluate_packet(sub) for sub in packet['packets'])
+        case 1:
+            return math.prod(evaluate_packet(sub) for sub in packet['packets'])
+        case 2:
+            return min(evaluate_packet(sub) for sub in packet['packets'])
+        case 3:
+            return max(evaluate_packet(sub) for sub in packet['packets'])
+        case 4:
+            return packet['literal']
+        case 5:
+            return int(evaluate_packet(packet['packets'][0]) > evaluate_packet(packet['packets'][1]))
+        case 6:
+            return int(evaluate_packet(packet['packets'][0]) < evaluate_packet(packet['packets'][1]))
+        case 7:
+            return int(evaluate_packet(packet['packets'][0]) == evaluate_packet(packet['packets'][1]))
+
+parsed = parse_packet()
+# part 1
+print(get_version_sum(parsed))
+
+# part 2
+print(evaluate_packet(parsed))