adventofcode-2021/unionfind.lua

53 lines
899 B
Lua

--- Register x as a singleton set: x is its own root and the set has size 1.
-- Leaves an existing entry for x untouched.
-- @param x element to add
local function makeset(self, x)
  local forest = self.forest
  if forest[x] ~= nil then
    return
  end
  forest[x] = { x, 1 }
end
--- Find the representative (root) of the set containing x.
-- Applies path compression: every entry visited on the way is repointed
-- directly at the root.
-- @param x element to look up
-- @return root element and the size of its set, or nil if x was never added
local function find(self, x)
  local entry = self.forest[x]
  if entry == nil then
    -- Unknown element: report absence without writing anything. Storing a
    -- placeholder here would make a later makeset(x) a silent no-op.
    return nil
  end
  -- Read fields directly rather than via unpack (5.1-only global).
  local parent, size = entry[1], entry[2]
  if parent == x then
    return parent, size
  end
  local root, rootsize = self:find(parent)
  -- Path compression: point x straight at its root.
  self.forest[x] = { root, rootsize }
  return root, rootsize
end
--- Merge the sets containing x and y, using union by size.
-- An element not yet present is first added as a singleton set, so
-- union never compares a nil size. Does nothing if x and y already share
-- a root.
-- @param x first element
-- @param y second element
local function union(self, x, y)
  self:makeset(x)
  self:makeset(y)
  -- Locals: these previously leaked as globals.
  local xroot, xsize = self:find(x)
  local yroot, ysize = self:find(y)
  if xroot == yroot then
    return
  end
  -- Hang the smaller tree under the larger one to keep trees shallow.
  if xsize < ysize then
    xroot, yroot = yroot, xroot
  end
  local merged = xsize + ysize
  self.forest[xroot] = { xroot, merged }
  self.forest[yroot] = { xroot, merged }
end
--- Iterate over the disjoint sets, yielding one (root, size) pair per set.
-- Use with a generic for: `for root, size in uf:unions() do ... end`.
-- Only entries that are their own parent (roots) are reported. Iteration
-- order follows `next` and is unspecified. Do not add new elements to the
-- forest while iterating, because Lua's `next` does not allow inserting
-- keys during a traversal.
local function unions(self)
  local key
  return function()
    local entry
    repeat
      key, entry = next(self.forest, key)
    until entry == nil or entry[1] == key
    if entry ~= nil then
      -- Explicit fields instead of unpack (5.1-only global).
      return entry[1], entry[2]
    end
    return nil
  end
end
--- Create an empty union-find structure.
-- The returned table holds its own `forest` (element -> {parent, size})
-- and exposes makeset/find/union/unions as methods callable with `:`.
-- @return a new union-find instance
function unionfind()
  local instance = { forest = {} }
  instance.makeset, instance.find = makeset, find
  instance.union, instance.unions = union, unions
  return instance
end