#!/usr/bin/env luajit
-- Arithmetic on nested pairs written in bracket notation, e.g. "[[1,2],3]".
--
-- A number is stored as a flat sequence of leaves, left to right, where each
-- leaf is { d = <bracket depth of the leaf>, n = <numeric value> }.
--
-- `chars` (and `printtable`, used only in the commented-out checks) are not
-- defined in this file; presumably they come from the "utils" module -- confirm.
-- The functions below are deliberately left global, as in the original script.

require("utils")

--- Parse one line of bracket notation into a flat leaf list.
-- "[" raises the current depth, "]" lowers it, "," is ignored, and every other
-- character is converted with tonumber (so, assuming `chars` yields one
-- character at a time, each leaf is a single digit).
-- @param line text such as "[[1,2],3]"
-- @return sequence of { d = depth, n = value } leaves
function parse(line)
  local number = {}
  local depth = 0
  for char in chars(line) do
    if char == "[" then
      depth = depth + 1
    elseif char == "]" then
      depth = depth - 1
    elseif char ~= "," then
      local value = tonumber(char)
      -- Skip anything tonumber rejects (e.g. a trailing "\r" or a space)
      -- instead of storing a leaf with n = nil that breaks later arithmetic.
      if value ~= nil then
        number[#number + 1] = { d = depth, n = value }
      end
    end
  end
  return number
end

-- Append a copy of every leaf of `operand` to `sum`, one level deeper.
local function append_deeper(sum, operand)
  for _, digit in ipairs(operand) do
    sum[#sum + 1] = { d = digit.d + 1, n = digit.n }
  end
end

--- Combine two leaf lists into one new pair [a, b].
-- Every leaf of both operands is copied one level deeper; neither input is
-- modified and the result is NOT reduced.
-- @param a left operand (leaf list)
-- @param b right operand (leaf list)
-- @return new leaf list
function add(a, b)
  local sum = {}
  append_deeper(sum, a)
  append_deeper(sum, b)
  return sum
end

--- Explode the first leaf that sits deeper than 4 levels, if any.
-- That leaf and the one directly after it are treated as a pair: the first
-- value is added to the leaf on the left (if there is one), the second value
-- is added to the leaf on the right (if there is one), and the pair itself is
-- replaced by a single 0 leaf one level shallower. At most one explosion is
-- performed per call.
-- The input list and its leaves are never modified; untouched leaves are
-- shared between the input and the returned list.
-- @param a leaf list
-- @return exploded true when an explosion happened
-- @return new the resulting leaf list
function explode(a)
  local new = {}
  local exploded = false
  local i = 1
  while i <= #a do
    local leaf = a[i]
    if exploded or leaf.d <= 4 then
      new[#new + 1] = leaf
      i = i + 1
    else
      exploded = true
      local left = new[#new]
      if left then
        -- `left` is the same table object as the entry in `a`, so write a
        -- fresh leaf rather than updating it in place.
        new[#new] = { d = left.d, n = left.n + leaf.n }
      end
      new[#new + 1] = { d = leaf.d - 1, n = 0 }
      if #a >= i + 2 then
        local right = a[i + 2]
        new[#new + 1] = { d = right.d, n = a[i + 1].n + right.n }
      end
      -- Skip both halves of the pair and the right neighbour emitted above.
      i = i + 3
    end
  end
  return exploded, new
end

-- printtable({ explode(parse("[[[[[9,8],1],2],3],4]")) })
-- printtable({ explode(parse("[7,[6,[5,[4,[3,2]]]]]")) })
-- printtable({ explode(parse("[[6,[5,[4,[3,2]]]],1]")) })
-- printtable({ explode(parse("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]")) })
-- printtable({ explode(parse("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]")) })

--- Split the first leaf whose value is 10 or more, if any.
-- The leaf is replaced by two leaves one level deeper holding the value halved
-- and rounded down, then halved and rounded up. At most one split per call.
-- @param a leaf list
-- @return true when a split happened
-- @return the resulting leaf list
function split(a)
  local new = {}
  local done = false
  for _, digit in ipairs(a) do
    if done or digit.n < 10 then
      new[#new + 1] = digit
    else
      done = true
      local half = digit.n / 2
      new[#new + 1] = { d = digit.d + 1, n = math.floor(half) }
      new[#new + 1] = { d = digit.d + 1, n = math.ceil(half) }
    end
  end
  return done, new
end

--- Repeatedly explode and split until neither applies.
-- An explosion is always tried first; a split is only attempted when nothing
-- exploded in that round.
-- @param a leaf list
-- @return the fully reduced leaf list
function reduce(a)
  local changed
  repeat
    changed, a = explode(a)
    if not changed then
      changed, a = split(a)
    end
  until not changed
  return a
end

--- Rebuild the nested structure from a flat leaf list.
-- Pairs become two-element tables { left, right }; leaves become plain numbers.
-- @param number leaf list
-- @return the nested tree
-- @return the index just past the last leaf consumed
function reconstruct(number)
  -- d: depth of the node being built, c: index of the next unread leaf.
  local function go(d, c)
    if number[c].d > d then
      -- The next leaf lies deeper than this node, so this node is a pair.
      local left, right
      left, c = go(d + 1, c)
      right, c = go(d + 1, c)
      return { left, right }, c
    end
    return number[c].n, c + 1
  end
  return go(0, 1)
end

--- Write a leaf list to standard output in bracket notation, plus a newline.
-- @param a leaf list
function printnumber(a)
  local function go(node)
    if type(node) == "number" then
      io.write(tostring(node))
    else
      io.write("[")
      go(node[1])
      io.write(",")
      go(node[2])
      io.write("]")
    end
  end
  -- reconstruct also returns an index; only the tree is needed here.
  local tree = reconstruct(a)
  go(tree)
  io.write("\n")
end

-- printnumber(reduce(parse("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]")))
-- printnumber(reduce(add(parse("[[[[4,3],4],4],[7,[[8,4],9]]]"),parse("[1,1]"))))
-- printnumber(parse("[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]"))
-- printtable(parse("[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]"))
-- printnumber(add(parse("[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]"),parse("[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]")))
-- printnumber(reduce(add(parse("[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]"),parse("[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]"))))

--- Compute the magnitude of a number.
-- A leaf's magnitude is its value; a pair's magnitude is three times the
-- magnitude of its left element plus twice the magnitude of its right element.
-- @param a leaf list
-- @return the magnitude as a number
function magnitude(a)
  local function go(node)
    if type(node) == "number" then
      return node
    end
    return 3 * go(node[1]) + 2 * go(node[2])
  end
  local tree = reconstruct(a)
  return go(tree)
end

-- Driver: read one number per line, fold them together with add + reduce
-- (the first number is used as read, without reducing), print the magnitude.
-- NOTE(review): the input path is taken from arg[3]; confirm that matches how
-- this script is invoked.
local f = assert(io.open(arg[3]))  -- fail with the OS error message if unreadable
local number = parse(f:read("*l"))
for line in f:lines() do
  number = reduce(add(number, parse(line)))
end
f:close()
print(magnitude(number))