diff --git a/day24/part1.lua b/day24/part1.lua index 455d8be..01a2ee6 100755 --- a/day24/part1.lua +++ b/day24/part1.lua @@ -1,120 +1,50 @@ #!/usr/bin/env luajit require("utils") -local function trunc(a) - if a < 0 then - return math.ceil(a) - else - return math.floor(a) - end -end - -local map = { w=1, x=2, y=3, z=4 } -local instructions = {} -for line in io.lines(arg[3]) do +local blocks = {} +local current = nil +local f = io.open(arg[3]) +local line = f:read("*l") +while line do local inst, a1, a2 = unpack(split(line, " ")) if inst == "inp" then - table.insert(instructions, function(memory, input) - local i, o = 9, memory[a1] - return function() - if i > 1 then - i = i - 1 - memory[a1] = i - return input * 10 + i - else - memory[a1] = o - end - end - end) + if current ~= nil then + current = string.format("return function(w, x, y, z, input)\n%s\nreturn w, x, y, z\nend", current) + table.insert(blocks, assert(loadstring(current))()) + end + current = a1.." = input" elseif inst == "add" then - table.insert(instructions, function(memory, input) - local done, o = false, memory[a1] - return function() - if not done then - memory[a1] = memory[a1] + (memory[a2] or tonumber(a2)) - done = true - return input - else - memory[a1] = o - end - end - end) + current = string.format("%s\n%s = %s + %s", current, a1, a1, a2) elseif inst == "mul" then - table.insert(instructions, function(memory, input) - local done, o = false, memory[a1] - return function() - if not done then - memory[a1] = memory[a1] * (memory[a2] or tonumber(a2)) - done = true - return input - else - memory[a1] = o - end - end - end) + current = string.format("%s\n%s = %s * %s", current, a1, a1, a2) elseif inst == "div" then - table.insert(instructions, function(memory, input) - local done, o = false, memory[a1] - return function() - if not done and (memory[a2] or tonumber(a2)) ~= 0 then - memory[a1] = trunc(memory[a1] / (memory[a2] or tonumber(a2))) - done = true - return input - else - memory[a1] = o - end - end - end) + current = string.format("%s\nif %s == 0 then return end\n%s = math.modf(%s / %s)", current, a2, a1, a1, a2) elseif inst == "mod" then - table.insert(instructions, function(memory, input) - local done, o = false, memory[a1] - return function() - local v = memory[a2] or tonumber(a2) - if not done and memory[a1] >= 0 and v > 0 then - memory[a1] = memory[a1] % v - done = true - return input - else - memory[a1] = o - end - end - end) + current = string.format("%s\nif %s < 0 or %s <= 0 then return end\n%s = %s %% %s", current, a1, a2, a1, a1, a2) elseif inst == "eql" then - table.insert(instructions, function(memory, input) - local done, o = false, memory[a1] - return function() - if not done then - if memory[a1] == (memory[a2] or tonumber(a2)) then - memory[a1] = 1 - else - memory[a1] = 0 - end - done = true - return input - else - memory[a1] = o - end - end - end) + current = string.format("%s\nif %s == %s then %s = 1 else %s = 0 end", current, a1, a2, a1, a1) end + line = f:read("*l") +end +if current ~= nil then + current = string.format("return function(w, x, y, z, input)\n%s\nreturn w, x, y, z\nend", current) + table.insert(blocks, assert(loadstring(current))()) end -local maximum = 0 -local function backtrack(i, memory, input) - if i == #instructions then - if memory.z == 0 and input > maximum then - print(input) - print(os.date()) - maximum = input +local function backtrack(i, w, x, y, z, output) + if i > #blocks then + if z == 0 then + print(string.format("%d", output)) + os.exit() end else - -- io.write("i "..tostring(i).." ") - -- printtable(memory) - for newinput in instructions[i](memory, input) do - backtrack(i + 1, memory, newinput) + if i < 7 then print(output) end + for n = 9, 1, -1 do + local w0, x0, y0, z0 = blocks[i](w, x, y, z, n) + if w0 then + backtrack(i + 1, w0, x0, y0, z0, output * 10 + n) + end end end end - -backtrack(1, { w=0, x=0, y=0, z=0 }, 0) -print(maximum) +backtrack(1, 0, 0, 0, 0, 0) diff --git a/day24/part2.lua b/day24/part2.lua new file mode 100755 index 0000000..3514fb4 --- /dev/null +++ b/day24/part2.lua @@ -0,0 +1,50 @@ +#!/usr/bin/env luajit +require("utils") + +local blocks = {} +local current = nil +local f = io.open(arg[3]) +local line = f:read("*l") +while line do + local inst, a1, a2 = unpack(split(line, " ")) + if inst == "inp" then + if current ~= nil then + current = string.format("return function(w, x, y, z, input)\n%s\nreturn w, x, y, z\nend", current) + table.insert(blocks, assert(loadstring(current))()) + end + current = a1.." = input" + elseif inst == "add" then + current = string.format("%s\n%s = %s + %s", current, a1, a1, a2) + elseif inst == "mul" then + current = string.format("%s\n%s = %s * %s", current, a1, a1, a2) + elseif inst == "div" then + current = string.format("%s\nif %s == 0 then return end\n%s = math.modf(%s / %s)", current, a2, a1, a1, a2) + elseif inst == "mod" then + current = string.format("%s\nif %s < 0 or %s <= 0 then return end\n%s = %s %% %s", current, a1, a2, a1, a1, a2) + elseif inst == "eql" then + current = string.format("%s\nif %s == %s then %s = 1 else %s = 0 end", current, a1, a2, a1, a1) + end + line = f:read("*l") +end +if current ~= nil then + current = string.format("return function(w, x, y, z, input)\n%s\nreturn w, x, y, z\nend", current) + table.insert(blocks, assert(loadstring(current))()) +end + +local function backtrack(i, w, x, y, z, output) + if i > #blocks then + if z == 0 then + print(string.format("%d", output)) + os.exit() + end + else + if i < 7 then print(output) end + for n = 1, 9 do + local w0, x0, y0, z0 = blocks[i](w, x, y, z, n) + if w0 then + backtrack(i + 1, w0, x0, y0, z0, output * 10 + n) + end + end + end +end +backtrack(1, 0, 0, 0, 0, 0)