summary refs log tree commit diff stats
path: root/day8.py
blob: b78bd1aef622dcb62850dc29387d473b5ef57be8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python

# Parse the puzzle input: each line is "<10 patterns> | <4 output digits>",
# with patterns separated by whitespace. Blank lines (e.g. a trailing empty
# line) are skipped instead of crashing the two-way unpack below.
with open("day8.txt", encoding="utf-8") as data:
    signals, outputs = [], []
    for line in data:
        if not line.strip():
            continue
        signal, output = line.split('|')
        # str.split() with no argument also discards surrounding whitespace.
        signals.append(signal.split())
        outputs.append(output.split())
# part 1: count output digits drawn with a unique number of segments
# (1 uses 2 segments, 7 uses 3, 4 uses 4, 8 uses 7).
total = sum(
    1
    for output in outputs
    for digit in output
    if len(digit) in {2, 3, 4, 7}
)
print(total)

# part 2

def lookup(signalmap):
    """Return a decoder from a segment pattern to its digit string.

    ``signalmap`` maps each key (converted with ``str()`` for the result)
    to a collection of segment letters. The returned function accepts any
    iterable of segment letters, in any order, and returns the matching
    key as a string. An unknown pattern raises ``KeyError``.
    """
    table = {}
    for digit, segments in signalmap.items():
        table[frozenset(segments)] = str(digit)

    def lookup_func(signals):
        # Compare as an unordered set so "ab" and "ba" decode the same way.
        pattern = frozenset(signals)
        return table[pattern]

    return lookup_func

total = 0
for signal, output in zip(signals, outputs):
    # Work on a length-sorted copy instead of sorting ``signal`` in place,
    # so the parsed ``signals`` data is not mutated. sorted() is stable.
    # For a complete line the lengths are 2, 3, 4, 5, 5, 5, 6, 6, 6, 7.
    patterns = [set(p) for p in sorted(signal, key=len)]

    # determine what goes where: digits with a unique segment count first
    signalmap = {
        1: patterns[0],  # 2 segments
        7: patterns[1],  # 3 segments
        4: patterns[2],  # 4 segments
        8: patterns[9],  # 7 segments
    }

    # 5 segments: 2, 3, 5
    for digit in patterns[3:6]:
        if signalmap[1] <= digit:
            signalmap[3] = digit  # of 2/3/5, only 3 contains all of 1
        elif len(digit & signalmap[4]) == 2:
            signalmap[2] = digit  # 2 shares two segments with 4
        elif len(digit & signalmap[4]) == 3:
            signalmap[5] = digit  # 5 shares three segments with 4
    # 6 segments: 0, 6, 9
    for digit in patterns[6:9]:
        if signalmap[4] <= digit:
            signalmap[9] = digit  # of 0/6/9, only 9 contains all of 4
        elif signalmap[1] <= digit:
            signalmap[0] = digit  # 0 contains 1 but not 4
        elif len(signalmap[1] & digit) == 1:
            signalmap[6] = digit  # 6 contains only one segment of 1

    # Fail loudly here rather than with an opaque KeyError during decoding.
    if len(signalmap) != 10:
        raise ValueError(f"could not identify all ten digits in {signal}")

    lookup_func = lookup(signalmap)
    # decode output
    out_string = "".join(lookup_func(x) for x in output)
    total += int(out_string)

print(total)