diff --git a/day08/part2.lua b/day08/part2.lua new file mode 100755 index 0000000..39698e0 --- /dev/null +++ b/day08/part2.lua @@ -0,0 +1,110 @@ +#!/usr/bin/env luajit +require("utils") +require("set") + +local entries = {} +for line in io.lines(arg[3]) do + local entry = split(line, " ") + table.remove(entry, 11) + table.insert(entries, entry) +end + +function uniques(segment) + if #segment == 2 then return "cf" end + if #segment == 3 then return "acf" end + if #segment == 4 then return "bcdf" end +end + +function reduceoptions(options, segment) + local included = uniques(segment) + if included then + for part in chars("abcdgef") do + if string.find(segment, part) then + options[part]:keeponly(chars(included)) + else + options[part]:removeall(chars(included)) + end + end + end +end + +function printoptions(options) + for from, s in pairs(options) do + io.write(from..":") + for to in s:iter() do + io.write(to) + end + io.write(" ") + end + io.write("\n") +end + +function asdigit(segment, sel) + local lit = set() + for c in chars(segment) do + lit:add(sel[c]) + end + + if lit:equal(set(chars("acbefg"))) then return 0 end + if lit:equal(set(chars("cf"))) then return 1 end + if lit:equal(set(chars("acdeg"))) then return 2 end + if lit:equal(set(chars("acdfg"))) then return 3 end + if lit:equal(set(chars("bcdf"))) then return 4 end + if lit:equal(set(chars("abdfg"))) then return 5 end + if lit:equal(set(chars("abdefg"))) then return 6 end + if lit:equal(set(chars("acf"))) then return 7 end + if lit:equal(set(chars("abcdefg"))) then return 8 end + if lit:equal(set(chars("abcdfg"))) then return 9 end +end + +function backtrack(options, displays, display, segment, acc, sel, rev) + if display > #displays then + return set({acc}) + elseif segment <= #displays[display] then + local pickfor = char(displays[display], segment) + local picked = sel[pickfor] + if picked then + return backtrack(options, displays, display, segment+1, acc, sel, rev) + else + local accs = set() + for picked in options[char(displays[display], segment)]:iter() do + sel[pickfor] = picked + if not rev:has(picked) then + rev:add(picked) + accs:addall(backtrack(options, displays, display, segment+1, acc, sel, rev):iter()) + rev:remove(picked) + end + end + sel[pickfor] = nil + return accs + end + else + local digit = asdigit(displays[display], sel) + if digit then + return backtrack(options, displays, display + 1, 1, (acc * 10 + digit) % 10000, sel, rev) + else + return set() + end + end +end + +function outputvalue(entry) + local options = { a = set(chars('abcdefg')), + b = set(chars('abcdefg')), + c = set(chars('abcdefg')), + d = set(chars('abcdefg')), + e = set(chars('abcdefg')), + f = set(chars('abcdefg')), + g = set(chars('abcdefg')) } + + for k, v in pairs(entry) do reduceoptions(options, v) end + local solutions = backtrack(options, entry, 1, 1, 0, {}, set()) + if solutions.size == 1 then + return solutions:iter()() + else + error("meh") + end +end + +foreach(entries, outputvalue) +print(sum(entries)) diff --git a/set.lua b/set.lua new file mode 100644 index 0000000..5621c48 --- /dev/null +++ b/set.lua @@ -0,0 +1,95 @@ +local function add(self, value) + if not self.values[value] then + self.values[value] = true + self.size = self.size + 1 + end +end + +local function has(self, value) + return self.values[value] ~= nil +end + +local function remove(self, value) + if self.values[value] then + self.values[value] = nil + self.size = self.size - 1 + end +end + +local function iter(self) + local index, item + return function() + repeat + index, item = next(self.values, index) + until index == nil or item ~= nil + if index == nil then + return nil + else + return index + end + end +end + +local function keeponly(self, iterable) + local tokeep = set(iterable) + for v in self:iter() do + if not tokeep:has(v) then + self:remove(v) + end + end +end + +local function removeall(self, iterable) + for v in iterable do + self:remove(v) + end +end + +local function addall(self, iterable) + for v in iterable do + self:add(v) + end +end + +local function equal(self, other) + for c in self:iter() do if not other:has(c) then return false end end + for c in other:iter() do if not self:has(c) then return false end end + return true +end + +local function write(self) + io.write('{') + for v in self:iter() do + io.write(' '..v) + end + io.write(' }\n') +end + +function set(t) + local values = {} + local size = 0 + if type(t) == "table" then + for k, v in pairs(t) do + size = size + 1 + values[v] = true + end + elseif type(t) == "function" then + for v in t do + size = size + 1 + values[v] = true + end + end + return { + values = values, + size = size, + add = add, + has = has, + remove = remove, + iter = iter, + keeponly = keeponly, + removeall = removeall, + addall = addall, + equal = equal, + write = write, + } +end diff --git a/utils.lua b/utils.lua index 9dae1cb..1836650 100644 --- a/utils.lua +++ b/utils.lua @@ -38,3 +38,30 @@ function count(t, pred) end return count end + +function foldl(t, f, a) + for k, v in pairs(t) do + a = f(a, v) + end + return a +end + +function sum(t) + return foldl(t, function (x, y) return x + y end, 0) +end + +function char(s, index) + return string.char(string.byte(s, index)) +end + +function chars(s) + local index = 0 + return function() + if index < #s then + index = index + 1 + return char(s, index) + else + return nil + end + end +end