add unionfind to day 9

This commit is contained in:
Felix Van der Jeugt 2021-12-09 09:59:20 +01:00
parent 9c8cd8e7a9
commit 5df13ab159
No known key found for this signature in database
GPG Key ID: 58B209295023754D
2 changed files with 60 additions and 42 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/env luajit #!/usr/bin/env luajit
require("utils") require("utils")
require("unionfind")
local hm = {} local hm = {}
for line in io.lines(arg[3]) do for line in io.lines(arg[3]) do
@ -14,58 +15,24 @@ local function get(r, c)
return (hm[r] or {})[c] or 9 return (hm[r] or {})[c] or 9
end end
local union = {} local union = unionfind()
for r, row in pairs(hm) do for r, row in pairs(hm) do
for c, col in pairs(row) do for c, col in pairs(row) do
if col ~= 9 then if col ~= 9 then
if get(r - 1, c) ~= 9 then if get(r - 1, c) ~= 9 then
if union[r] == nil then union[r] = {} end union:makeset(r, c)
union[r][c] = { r - 1, c } union:makeset(r - 1, c)
union:union(r, c, r - 1, c)
end end
if get(r, c - 1) ~= 9 then if get(r, c - 1) ~= 9 then
if union[r] == nil then union[r] = {} end union:makeset(r, c)
if union[r][c] ~= nil then union:makeset(r, c - 1)
local r0, c0 = r, c - 1 union:union(r, c, r, c - 1)
while type(union[r0][c0]) == "table" do
r0, c0 = unpack(union[r0][c0])
if union[r0] == nil then union[r0] = {} end
end
local r1, c1 = r, c
while type(union[r1][c1]) == "table" do
r1, c1 = unpack(union[r1][c1])
if union[r1] == nil then union[r1] = {} end
end
if r0 ~= r1 or c0 ~= c1 then
union[r1][c1] = { r0, c0 }
end
else
union[r][c] = { r, c - 1 }
end
end end
end end
end end
end end
local sums = {} local sizes = union:sizes()
for r, row in pairs(hm) do
for c, col in pairs(row) do
if col ~= 9 then
local r0, c0 = r, c
while type(union[r0][c0]) == "table" do
r0, c0 = unpack(union[r0][c0])
end
if sums[r0] == nil then sums[r0] = {} end
sums[r0][c0] = (sums[r0][c0] or 0) + 1
end
end
end
local sizes = {}
for r, row in pairs(sums) do
for c, col in pairs(row) do
table.insert(sizes, col)
end
end
table.sort(sizes) table.sort(sizes)
print(sizes[#sizes] * sizes[#sizes - 1] * sizes[#sizes - 2]) print(sizes[#sizes] * sizes[#sizes - 1] * sizes[#sizes - 2])

51
unionfind.lua Normal file
View File

@ -0,0 +1,51 @@
local function makeset(self, r, c)
if self.forest[r] == nil then self.forest[r] = {} end
if self.forest[r][c] == nil then
self.forest[r][c] = { r, c, 1 }
end
end
local function find(self, r, c)
if self.forest[r] == nil then self.forest[r] = {} end
local r0, c0, s0 = unpack(self.forest[r][c] or {})
if r0 == r and c0 == c then
return r0, c0, s0
else
self.forest[r][c] = { self:find(r0, c0) }
return unpack(self.forest[r][c])
end
end
local function union(self, ar, ac, br, bc)
a0r, a0c, a0s = self:find(ar, ac)
b0r, b0c, b0s = self:find(br, bc)
if a0r == b0r and a0c == b0c then return end
if a0s < b0s then
a0r, a0c, a0s, b0r, b0c, b0s = b0r, b0c, b0s, a0r, a0c, a0s
end
self.forest[b0r][b0c] = { a0r, a0c, a0s + b0s }
self.forest[a0r][a0c] = { a0r, a0c, a0s + b0s }
end
local function sizes(self)
local sizes = {}
for r, row in pairs(self.forest) do
for c, col in pairs(row) do
if r == col[1] and c == col[2] then
table.insert(sizes, col[3])
end
end
end
return sizes
end
function unionfind()
local forest = {}
return {
forest = forest,
makeset = makeset,
find = find,
union = union,
sizes = sizes,
}
end