diff --git a/day09/part2.lua b/day09/part2.lua index ef2f19e..e9ddea1 100755 --- a/day09/part2.lua +++ b/day09/part2.lua @@ -1,5 +1,6 @@ #!/usr/bin/env luajit require("utils") +require("unionfind") local hm = {} for line in io.lines(arg[3]) do @@ -14,58 +15,24 @@ local function get(r, c) return (hm[r] or {})[c] or 9 end -local union = {} +local union = unionfind() for r, row in pairs(hm) do for c, col in pairs(row) do if col ~= 9 then if get(r - 1, c) ~= 9 then - if union[r] == nil then union[r] = {} end - union[r][c] = { r - 1, c } + union:makeset(r, c) + union:makeset(r - 1, c) + union:union(r, c, r - 1, c) end if get(r, c - 1) ~= 9 then - if union[r] == nil then union[r] = {} end - if union[r][c] ~= nil then - local r0, c0 = r, c - 1 - while type(union[r0][c0]) == "table" do - r0, c0 = unpack(union[r0][c0]) - if union[r0] == nil then union[r0] = {} end - end - local r1, c1 = r, c - while type(union[r1][c1]) == "table" do - r1, c1 = unpack(union[r1][c1]) - if union[r1] == nil then union[r1] = {} end - end - if r0 ~= r1 or c0 ~= c1 then - union[r1][c1] = { r0, c0 } - end - else - union[r][c] = { r, c - 1 } - end + union:makeset(r, c) + union:makeset(r, c - 1) + union:union(r, c, r, c - 1) end end end end -local sums = {} -for r, row in pairs(hm) do - for c, col in pairs(row) do - if col ~= 9 then - local r0, c0 = r, c - while type(union[r0][c0]) == "table" do - r0, c0 = unpack(union[r0][c0]) - end - if sums[r0] == nil then sums[r0] = {} end - sums[r0][c0] = (sums[r0][c0] or 0) + 1 - end - end -end - -local sizes = {} -for r, row in pairs(sums) do - for c, col in pairs(row) do - table.insert(sizes, col) - end -end - +local sizes = union:sizes() table.sort(sizes) print(sizes[#sizes] * sizes[#sizes - 1] * sizes[#sizes - 2]) diff --git a/unionfind.lua b/unionfind.lua new file mode 100644 index 0000000..5379901 --- /dev/null +++ b/unionfind.lua @@ -0,0 +1,51 @@ +local function makeset(self, r, c) + if self.forest[r] == nil then self.forest[r] = {} end + if self.forest[r][c] == nil then + self.forest[r][c] = { r, c, 1 } + end +end + +local function find(self, r, c) + if self.forest[r] == nil then self.forest[r] = {} end + local r0, c0, s0 = unpack(self.forest[r][c] or {}) + if r0 == r and c0 == c then + return r0, c0, s0 + else + self.forest[r][c] = { self:find(r0, c0) } + return unpack(self.forest[r][c]) + end +end + +local function union(self, ar, ac, br, bc) + a0r, a0c, a0s = self:find(ar, ac) + b0r, b0c, b0s = self:find(br, bc) + if a0r == b0r and a0c == b0c then return end + if a0s < b0s then + a0r, a0c, a0s, b0r, b0c, b0s = b0r, b0c, b0s, a0r, a0c, a0s + end + self.forest[b0r][b0c] = { a0r, a0c, a0s + b0s } + self.forest[a0r][a0c] = { a0r, a0c, a0s + b0s } +end + +local function sizes(self) + local sizes = {} + for r, row in pairs(self.forest) do + for c, col in pairs(row) do + if r == col[1] and c == col[2] then + table.insert(sizes, col[3]) + end + end + end + return sizes +end + +function unionfind() + local forest = {} + return { + forest = forest, + makeset = makeset, + find = find, + union = union, + sizes = sizes, + } +end