about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorelioat <hi@eli.li>2024-06-09 12:07:13 -0400
committerelioat <hi@eli.li>2024-06-09 12:07:13 -0400
commit1b5475e67cd31d6837a02e94e110f2091dc059d0 (patch)
treea67835430bcc1a8b9377cfa5fd5d2e9838192705
parent1b6ac89b5051e7f010cb9cbd23cc5465fc950b64 (diff)
downloadtour-1b5475e67cd31d6837a02e94e110f2091dc059d0.tar.gz
*:
-rw-r--r--lua/chupacabra/chupacabra.lua116
-rw-r--r--lua/chupacabra/test_chupacabra.lua9
2 files changed, 106 insertions, 19 deletions
diff --git a/lua/chupacabra/chupacabra.lua b/lua/chupacabra/chupacabra.lua
index e4211e5..963d17d 100644
--- a/lua/chupacabra/chupacabra.lua
+++ b/lua/chupacabra/chupacabra.lua
@@ -50,32 +50,114 @@ function chupacabra.evaluate(tokens, context)
             local b = table.remove(stack)
             local a = table.remove(stack)
             table.insert(stack, a / b)
-        elseif token == "map+" then -- FIXME: fix all map functions to either add 2 arrays or spread
+        elseif token == "@+" then
             local a = table.remove(stack)
-            for i, v in ipairs(a) do
-                a[i] = v + 1
+            local b = table.remove(stack)
+            if type(a) == "table" and type(b) == "table" then
+            if #a ~= #b then
+                error("Arrays must have equal length")
+            end
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] + b[i])
+            end
+            table.insert(stack, result)
+            elseif type(a) == "table" and type(b) == "number" then
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] + b)
+            end
+            table.insert(stack, result)
+            elseif type(a) == "number" and type(b) == "table" then
+            local result = {}
+            for i = 1, #b do
+                table.insert(result, a + b[i])
+            end
+            table.insert(stack, result)
+            else
+            error("Invalid operands for addition")
             end
-            table.insert(stack, a)
-        elseif token == "map*" then
+        elseif token == "@*" then
             local a = table.remove(stack)
-            for i, v in ipairs(a) do
-                a[i] = v * 2
+            local b = table.remove(stack)
+            if type(a) == "table" and type(b) == "table" then
+            if #a ~= #b then
+                error("Arrays must have equal length")
+            end
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] * b[i])
+            end
+            table.insert(stack, result)
+            elseif type(a) == "table" and type(b) == "number" then
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] * b)
+            end
+            table.insert(stack, result)
+            elseif type(a) == "number" and type(b) == "table" then
+            local result = {}
+            for i = 1, #b do
+                table.insert(result, a * b[i])
+            end
+            table.insert(stack, result)
+            else
+            error("Invalid operands for multiplication")
             end
-            table.insert(stack, a)
-        elseif token == "map-" then
+        elseif token == "@-" then
             local a = table.remove(stack)
-            for i, v in ipairs(a) do
-                a[i] = v - 1
+            local b = table.remove(stack)
+            if type(a) == "table" and type(b) == "table" then
+            if #a ~= #b then
+                error("Arrays must have equal length")
+            end
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] - b[i])
+            end
+            table.insert(stack, result)
+            elseif type(a) == "table" and type(b) == "number" then
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] - b)
+            end
+            table.insert(stack, result)
+            elseif type(a) == "number" and type(b) == "table" then
+            local result = {}
+            for i = 1, #b do
+                table.insert(result, a - b[i])
+            end
+            table.insert(stack, result)
+            else
+            error("Invalid operands for subtraction")
             end
-            table.insert(stack, a)
-        elseif token == "map/" then
+        elseif token == "@/" then
             local a = table.remove(stack)
-            for i, v in ipairs(a) do
-                a[i] = v / 2
+            local b = table.remove(stack)
+            if type(a) == "table" and type(b) == "table" then
+            if #a ~= #b then
+                error("Arrays must have equal length")
+            end
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] / b[i])
+            end
+            table.insert(stack, result)
+            elseif type(a) == "table" and type(b) == "number" then
+            local result = {}
+            for i = 1, #a do
+                table.insert(result, a[i] / b)
             end
-            table.insert(stack, a)
+            table.insert(stack, result)
+            elseif type(a) == "number" and type(b) == "table" then
+            local result = {}
+            for i = 1, #b do
+                table.insert(result, a / b[i])
+            end
+            table.insert(stack, result)
         else
-            error("invalid token: " .. token)
+            error("Invalid operands for division")
+            end
         end
     end
 
diff --git a/lua/chupacabra/test_chupacabra.lua b/lua/chupacabra/test_chupacabra.lua
index 1eb5a48..b2e5ef2 100644
--- a/lua/chupacabra/test_chupacabra.lua
+++ b/lua/chupacabra/test_chupacabra.lua
@@ -21,11 +21,16 @@ local function test_case(input, expected_output)
     print("Test passed: " .. input .. " => " .. output_str)
 end
 
-
+test_case("[1 1 1] [2 3 4] @+", {3, 4, 5})
+test_case("[1 1 1] 2 @+", {3, 3, 3})
+test_case("[2 3 4] 3 @-", {1, 0, -1})
+test_case("[2 3 4] 3 @*", {6, 9, 12})
+test_case("[2 3 4] [2 3 4] @*", {4, 9, 16})
+test_case("2 [12 6 4] @/", {6.0, 3.0, 2.0})
+test_case("[2 2 2] [24 12 16] @/", {12.0, 6.0, 8.0})
 test_case("1", 1)  -- 1
 test_case("2 1 pop", 2) -- 2
 test_case("[1 1]", {1, 1})
-test_case("[1 2 3] map+", {2, 3, 4})
 test_case("3 4 +", 7)  -- 3 + 4 = 7
 test_case("5 2 -", 3)  -- 5 - 2 = 3
 test_case("2 3 *", 6)  -- 2 * 3 = 6