summary refs log tree commit diff stats
path: root/day16.py
blob: 5adedfc425666bc4a41b94a898ec8450044d7804 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python
import math

# Lookup table from an uppercase hex digit to its 4-character binary string,
# e.g. 'A' -> '1010'; keys are ordered '0'..'9', 'A'..'F'.
hex2bin = {f'{nibble:X}': format(nibble, '04b') for nibble in range(16)}
# Load the puzzle input (a single hex string) and expand it into a string of
# '0'/'1' characters, four per hex digit.
with open('day16.txt', encoding='utf-8') as data:
    # hex2bin only has uppercase keys; normalize so lowercase hex input does
    # not raise KeyError.
    bits = "".join(hex2bin[char] for char in data.read().strip().upper())

pos = 0  # index of the next unread character in `bits`; advanced by read_bits
def read_bits(num_bits):
    """Consume and return the next ``num_bits`` characters of ``bits``.

    Advances the module-level cursor ``pos`` by ``num_bits``.

    Raises:
        ValueError: if fewer than ``num_bits`` bits remain. Without this check
            the slice would quietly come back short and ``pos`` would run past
            the end of the data, yielding wrong values downstream.
    """
    global pos
    end = pos + num_bits
    if end > len(bits):
        raise ValueError(
            f"bit stream truncated: need {num_bits} bits at position {pos}, "
            f"only {len(bits) - pos} left"
        )
    chunk = bits[pos:end]
    pos = end
    return chunk

def read_int(num_bits):
    """Consume ``num_bits`` bits from the stream and return them as a non-negative int."""
    return int(read_bits(num_bits), base=2)

def read_literal_int():
    """Decode the value of a literal packet (type id 4) at the current position.

    The value is stored as 5-bit groups: a leading flag bit followed by four
    value bits. Groups are consumed until one whose flag bit is '0'; the value
    bits are concatenated most-significant group first.
    """
    value = 0
    more_groups = True
    while more_groups:
        group = read_bits(5)
        value = (value << 4) + int(group[1:], 2)
        more_groups = group[0] != '0'
    return value

def parse_packet():
    """Parse one packet (recursively including sub-packets) starting at ``pos``.

    Returns a dict with 'version' and 'type_id'. Literal packets (type id 4)
    additionally carry 'literal' (an int); operator packets carry 'packets',
    a list of sub-packet dicts in stream order.

    An operator packet's 1-bit length type id selects how its sub-packets are
    delimited:
      0 -> the next 15 bits give the total bit length of the sub-packets;
      1 -> the next 11 bits give the number of sub-packets.

    Raises:
        ValueError: if length-delimited sub-packets do not end exactly at the
            declared bit length (i.e. a sub-packet overran its parent).
    """
    version = read_int(3)
    type_id = read_int(3)
    packet = {'version': version, 'type_id': type_id}
    if type_id == 4:
        packet['literal'] = read_literal_int()
        return packet

    if read_int(1) == 0:
        bit_length = read_int(15)
        # `pos` is the global cursor advanced by read_bits during each parse.
        stop_pos = pos + bit_length
        subpackets = []
        while pos < stop_pos:
            subpackets.append(parse_packet())
        if pos != stop_pos:
            raise ValueError(
                f"sub-packets overran declared length: expected to stop at "
                f"bit {stop_pos}, stopped at {pos}"
            )
    else:
        count = read_int(11)
        subpackets = [parse_packet() for _ in range(count)]

    packet['packets'] = subpackets
    return packet

def get_version_sum(packet):
    """Return the sum of ``packet``'s version and the versions of all its descendants.

    Packets containing a 'literal' key are leaves; every other packet is
    expected to have a 'packets' list of children.
    """
    own_version = packet['version']
    children = () if 'literal' in packet else packet['packets']
    return own_version + sum(map(get_version_sum, children))

def evaluate_packet(packet):
    match packet['type_id']:
        case 0:
            return sum(evaluate_packet(sub) for sub in packet['packets'])
        case 1:
            return math.prod(evaluate_packet(sub) for sub in packet['packets'])
        case 2:
            return min(evaluate_packet(sub) for sub in packet['packets'])
        case 3:
            return max(evaluate_packet(sub) for sub in packet['packets'])
        case 4:
            return packet['literal']
        case 5:
            return int(evaluate_packet(packet['packets'][0]) > evaluate_packet(packet['packets'][1]))
        case 6:
            return int(evaluate_packet(packet['packets'][0]) < evaluate_packet(packet['packets'][1]))
        case 7:
            return int(evaluate_packet(packet['packets'][0]) == evaluate_packet(packet['packets'][1]))

parsed = parse_packet()

# Print part 1 (sum of all version numbers), then part 2 (value of the
# outermost expression), each on its own line and in that order.
for solver in (get_version_sum, evaluate_packet):
    print(solver(parsed))