From e058039f35b99efa16a151906c0ea18816084ad6 Mon Sep 17 00:00:00 2001 From: Felix Van der Jeugt Date: Thu, 9 Dec 2021 11:30:02 +0100 Subject: [PATCH] generalize union find --- day09/part2.lua | 18 ++++++++++------ unionfind.lua | 57 +++++++++++++++++++++++++------------------------ 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/day09/part2.lua b/day09/part2.lua index e9ddea1..71092f4 100755 --- a/day09/part2.lua +++ b/day09/part2.lua @@ -10,6 +10,7 @@ for line in io.lines(arg[3]) do end table.insert(hm, row) end +local w = #hm[1] local function get(r, c) return (hm[r] or {})[c] or 9 @@ -20,19 +21,22 @@ for r, row in pairs(hm) do for c, col in pairs(row) do if col ~= 9 then if get(r - 1, c) ~= 9 then - union:makeset(r, c) - union:makeset(r - 1, c) - union:union(r, c, r - 1, c) + union:makeset(r*w + c) + union:makeset(r*w - w + c) + union:union(r*w + c, r*w - w + c) end if get(r, c - 1) ~= 9 then - union:makeset(r, c) - union:makeset(r, c - 1) - union:union(r, c, r, c - 1) + union:makeset(r*w + c) + union:makeset(r*w + c - 1) + union:union(r*w + c, r*w + c - 1) end end end end -local sizes = union:sizes() +local sizes = {} +for set, size in union:unions() do + table.insert(sizes, size) +end table.sort(sizes) print(sizes[#sizes] * sizes[#sizes - 1] * sizes[#sizes - 2]) diff --git a/unionfind.lua b/unionfind.lua index 5379901..d7f88b9 100644 --- a/unionfind.lua +++ b/unionfind.lua @@ -1,42 +1,43 @@ -local function makeset(self, r, c) - if self.forest[r] == nil then self.forest[r] = {} end - if self.forest[r][c] == nil then - self.forest[r][c] = { r, c, 1 } +local function makeset(self, x) + if self.forest[x] == nil then + self.forest[x] = { x, 1 } end end -local function find(self, r, c) - if self.forest[r] == nil then self.forest[r] = {} end - local r0, c0, s0 = unpack(self.forest[r][c] or {}) - if r0 == r and c0 == c then - return r0, c0, s0 +local function find(self, x) + local x0, s0 = unpack(self.forest[x] or {}) + if x == x0 then + return x0, s0 else - self.forest[r][c] = { self:find(r0, c0) } - return unpack(self.forest[r][c]) + x0, s0 = self:find(x0) + self.forest[x] = { x0, s0 } + return x0, s0 end end -local function union(self, ar, ac, br, bc) - a0r, a0c, a0s = self:find(ar, ac) - b0r, b0c, b0s = self:find(br, bc) - if a0r == b0r and a0c == b0c then return end - if a0s < b0s then - a0r, a0c, a0s, b0r, b0c, b0s = b0r, b0c, b0s, a0r, a0c, a0s +local function union(self, x, y) + x0, xs = self:find(x) + y0, ys = self:find(y) + if x0 == y0 then return end + if xs < ys then + x0, y0 = y0, x0 end - self.forest[b0r][b0c] = { a0r, a0c, a0s + b0s } - self.forest[a0r][a0c] = { a0r, a0c, a0s + b0s } + self.forest[x0] = { x0, xs + ys } + self.forest[y0] = { x0, xs + ys } end -local function sizes(self) - local sizes = {} - for r, row in pairs(self.forest) do - for c, col in pairs(row) do - if r == col[1] and c == col[2] then - table.insert(sizes, col[3]) - end +function unions(self) + local index, item + return function() + repeat + index, item = next(self.forest, index) + until item == nil or index == item[1] + if item ~= nil then + return unpack(item) + else + return nil end end - return sizes end function unionfind() @@ -46,6 +47,6 @@ function unionfind() makeset = makeset, find = find, union = union, - sizes = sizes, + unions = unions, } end