adventofcode-2021/unionfind.lua

52 lines
1.1 KiB
Lua

--- Register cell (r, c) as a singleton set unless it is already tracked.
-- Each forest entry is `{ parent_row, parent_col, size }`; a fresh entry is
-- its own parent with size 1. Existing entries are left untouched.
-- @param self union-find instance holding a `forest` table
-- @param r row key
-- @param c column key
local function makeset(self, r, c)
  local row = self.forest[r]
  if row == nil then
    row = {}
    self.forest[r] = row
  end
  if row[c] == nil then
    row[c] = { r, c, 1 }
  end
end
--- Find the representative (root) of the set containing cell (r, c).
-- A root is an entry whose stored parent is its own coordinates. Every
-- non-root entry visited is rewritten to point directly at the root
-- (path compression), carrying a copy of the root's size.
-- @param self union-find instance holding `forest` and a `find` method
-- @param r row key
-- @param c column key
-- @return root row, root column, and the size stored at the root
-- @raise error if (r, c) was never registered via makeset
local function find(self, r, c)
  local row = self.forest[r]
  local node = row and row[c]
  if node == nil then
    -- Without this guard the recursion would continue with nil keys and
    -- fail with an unrelated "table index is nil" error.
    error(string.format(
      "unionfind: (%s, %s) is not in any set; call makeset first",
      tostring(r), tostring(c)), 2)
  end
  local pr, pc = node[1], node[2]
  if pr == r and pc == c then
    return pr, pc, node[3]
  end
  -- Index the entry directly instead of using `unpack`, which is absent in
  -- Lua 5.2+ (moved to `table.unpack`).
  local rr, rc, rs = self:find(pr, pc)
  row[c] = { rr, rc, rs }
  return rr, rc, rs
end
--- Merge the sets containing cells (ar, ac) and (br, bc), union by size.
-- The root of the smaller set is attached beneath the root of the larger
-- one, and the surviving root's size becomes the combined size. Does
-- nothing if both cells already share a root.
-- @param self union-find instance holding `forest` and a `find` method
-- @param ar row of the first cell
-- @param ac column of the first cell
-- @param br row of the second cell
-- @param bc column of the second cell
local function union(self, ar, ac, br, bc)
  -- Declared `local` so these temporaries do not leak into the global table.
  local a0r, a0c, a0s = self:find(ar, ac)
  local b0r, b0c, b0s = self:find(br, bc)
  if a0r == b0r and a0c == b0c then return end
  if a0s < b0s then
    -- Ensure `a` is the larger tree so the smaller one goes beneath it.
    a0r, a0c, a0s, b0r, b0c, b0s = b0r, b0c, b0s, a0r, a0c, a0s
  end
  local total = a0s + b0s
  self.forest[b0r][b0c] = { a0r, a0c, total }
  self.forest[a0r][a0c] = { a0r, a0c, total }
end
--- Collect the size of every disjoint set in the forest.
-- An entry is a root when its stored parent equals its own coordinates;
-- the size stored at each root is appended. The order of the result
-- follows `pairs` traversal and is therefore unspecified.
-- @param self union-find instance holding a `forest` table
-- @return array of set sizes
local function sizes(self)
  local result = {}
  for row_key, row in pairs(self.forest) do
    for col_key, entry in pairs(row) do
      local is_root = entry[1] == row_key and entry[2] == col_key
      if is_root then
        result[#result + 1] = entry[3]
      end
    end
  end
  return result
end
--- Create an empty union-find structure over (row, column) cells.
-- Kept as a global function to preserve the existing interface.
-- @return table with a fresh empty `forest` and the makeset/find/union/sizes
--   methods (call them with `:`)
function unionfind()
  return {
    forest = {},
    makeset = makeset,
    find = find,
    union = union,
    sizes = sizes,
  }
end