adventofcode-2021/day16/part1.lua

85 lines
2.1 KiB
Lua
Executable File

#!/usr/bin/env luajit
require("utils")
-- Read the puzzle input: a single line of hexadecimal text. The path comes
-- from arg[3] (presumably supplied by whatever runner invokes this script —
-- confirm). Fail with a clear message if the file cannot be opened, and
-- close the handle once its contents have been read.
local input = assert(io.open(arg[3], "r"))
local hex = input:read("*a")
input:close()
--local hex = "8A004A801A8002F478"

-- Each hexadecimal digit expands to exactly four bits.
local hexmap = {
  ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
  ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111",
  ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011",
  ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111",
}

-- Expand the input into one string of '0'/'1' characters. Only hex digits
-- are visited, so the trailing newline (or any other stray character) is
-- skipped; lowercase digits are accepted by upper-casing before the lookup.
-- The loop variable is deliberately not called `char`, so the `char` helper
-- from utils is never shadowed.
local nibbles = {}
for digit in hex:gmatch("%x") do
  nibbles[#nibbles + 1] = hexmap[digit:upper()]
end
local bitstring = table.concat(nibbles)
--- Decode a literal value from the bits following a type-4 packet header.
-- The value is stored as 5-bit groups: a continuation bit ("1" means more
-- groups follow) followed by 4 value bits, most significant group first.
-- @tparam string bitstring bits starting at the first group
-- @treturn number the decoded value
-- @treturn string the bits remaining after the final group
local function literal(bitstring)
  local value, rest = 0, bitstring
  repeat
    -- Inspect the continuation bit before consuming this group.
    local is_last = char(rest, 1) ~= "1"
    value = value * 16 + bin2dec(rest:sub(2, 5))
    rest = rest:sub(6)
  until is_last
  return value, rest
end
--- Parse an operator's sub-packets when the length type ID is "1": the next
-- 11 bits give how many sub-packets follow.
-- @tparam string bitstring bits immediately after the length type ID
-- @treturn table the parsed sub-packets, in order
-- @treturn string the bits remaining after the last sub-packet
local function countoperator(bitstring)
  local how_many = bin2dec(bitstring:sub(1, 11))
  local rest = bitstring:sub(12)
  local children = {}
  for k = 1, how_many do
    -- `packet` is the global parser defined further down the file.
    children[k], rest = packet(rest)
  end
  return children, rest
end
--- Parse an operator's sub-packets when the length type ID is not "1": the
-- next 15 bits give the total number of bits the sub-packets occupy.
-- @tparam string bitstring bits immediately after the length type ID
-- @treturn table the parsed sub-packets, in order
-- @treturn string the bits remaining after the sub-packets
local function lengthoperator(bitstring)
  local span = bin2dec(bitstring:sub(1, 15))
  local rest = bitstring:sub(16)
  -- Progress is measured by how much of the string is left rather than by
  -- an offset: keep parsing until `span` bits have been consumed.
  local stop_len = #rest - span
  local children = {}
  while #rest > stop_len do
    local child
    child, rest = packet(rest)
    children[#children + 1] = child
  end
  return children, rest
end
--- Parse an operator packet body by dispatching on its first bit, the
-- length type ID: "1" selects a sub-packet count, anything else a bit length.
-- @tparam string bitstring bits starting at the length type ID
-- @treturn table the parsed sub-packets
-- @treturn string the bits remaining after them
local function operator(bitstring)
  -- Both candidates are functions (always truthy), so and/or is safe here.
  local parse_children = (char(bitstring, 1) == "1") and countoperator or lengthoperator
  return parse_children(bitstring:sub(2))
end
--- Parse one packet from the front of a bit string.
-- Intentionally global: countoperator and lengthoperator, defined above,
-- call `packet` before this definition is reached, so a later `local`
-- declaration would not be visible to them.
-- @tparam string bitstring bits starting at the packet header
-- @treturn table { version, pactype, content }, where content is the decoded
--   number for literal packets (type 4) and a list of sub-packets otherwise
-- @treturn string the bits left unconsumed
function packet(bitstring)
  local header_version = bin2dec(bitstring:sub(1, 3))
  local type_id = bin2dec(bitstring:sub(4, 6))
  local body = bitstring:sub(7)
  local parse_body = (type_id == 4) and literal or operator
  local content, rest = parse_body(body)
  return { version = header_version, pactype = type_id, content = content }, rest
end
--- Sum the version numbers of a packet and, recursively, all its sub-packets.
-- Declared `local` so it no longer leaks a global; `local function` still
-- allows the recursive call. The parameter is not named `packet`, so the
-- global parser function of that name is not shadowed.
-- @tparam table pkt a packet table with fields `version`, `pactype` and
--   `content` (a number for type 4, otherwise a list of sub-packets)
-- @treturn number the total of `version` over the whole packet tree
local function sumversions(pkt)
  local total = pkt.version
  -- Literal packets (type 4) hold a number in `content`, not children.
  if pkt.pactype ~= 4 then
    for _, child in ipairs(pkt.content) do
      total = total + sumversions(child)
    end
  end
  return total
end
-- Parse the outermost packet (the leftover padding bits it returns are not
-- used) and print the sum of every version number in the packet tree.
local outermost = packet(bitstring)
print(sumversions(outermost))