diff options
Diffstat (limited to 'js/scripting-lang/baba-yaga-c/src/parser.c')
-rw-r--r-- | js/scripting-lang/baba-yaga-c/src/parser.c | 2140 |
1 files changed, 2140 insertions, 0 deletions
diff --git a/js/scripting-lang/baba-yaga-c/src/parser.c b/js/scripting-lang/baba-yaga-c/src/parser.c new file mode 100644 index 0000000..8531b5a --- /dev/null +++ b/js/scripting-lang/baba-yaga-c/src/parser.c @@ -0,0 +1,2140 @@ +/** + * @file parser.c + * @brief Parser implementation for Baba Yaga + * @author eli_oat + * @version 0.0.1 + * @date 2025 + * + * This file implements the parser for the Baba Yaga language. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <assert.h> + +#include "baba_yaga.h" + +/* ============================================================================ + * Token Types (from lexer.c) + * ============================================================================ */ + +typedef enum { + TOKEN_EOF, + TOKEN_NUMBER, + TOKEN_STRING, + TOKEN_BOOLEAN, + TOKEN_IDENTIFIER, + TOKEN_KEYWORD_WHEN, + TOKEN_KEYWORD_IS, + TOKEN_KEYWORD_THEN, + TOKEN_KEYWORD_AND, + TOKEN_KEYWORD_OR, + TOKEN_KEYWORD_XOR, + TOKEN_KEYWORD_NOT, + TOKEN_KEYWORD_VIA, + TOKEN_OP_PLUS, + TOKEN_OP_MINUS, + TOKEN_OP_UNARY_MINUS, + TOKEN_OP_MULTIPLY, + TOKEN_OP_DIVIDE, + TOKEN_OP_MODULO, + TOKEN_OP_POWER, + TOKEN_OP_EQUALS, + TOKEN_OP_NOT_EQUALS, + TOKEN_OP_LESS, + TOKEN_OP_LESS_EQUAL, + TOKEN_OP_GREATER, + TOKEN_OP_GREATER_EQUAL, + TOKEN_LPAREN, + TOKEN_RPAREN, + TOKEN_LBRACE, + TOKEN_RBRACE, + TOKEN_LBRACKET, + TOKEN_RBRACKET, + TOKEN_COMMA, + TOKEN_COLON, + TOKEN_SEMICOLON, + TOKEN_ARROW, + TOKEN_DOT, + TOKEN_FUNCTION_REF, + TOKEN_IO_IN, + TOKEN_IO_OUT, + TOKEN_IO_ASSERT, + TOKEN_COMMENT +} TokenType; + +typedef struct { + TokenType type; + char* lexeme; + int line; + int column; + union { + double number; + bool boolean; + } literal; +} Token; + +/* ============================================================================ + * AST Node Types + * ============================================================================ */ + +/* NodeType enum is now defined in baba_yaga.h */ + +/* ============================================================================ + * AST Node Structure + * ============================================================================ */ + +struct ASTNode { + NodeType type; + int line; + int column; + union { + Value literal; + char* identifier; + struct { + struct ASTNode* left; + struct ASTNode* right; + char* operator; + } binary; + struct { + struct ASTNode* operand; + char* operator; + } unary; + struct { + struct ASTNode* function; + struct ASTNode** arguments; + int arg_count; + } function_call; + struct { + char* name; + struct ASTNode** parameters; + int param_count; + struct ASTNode* body; + } function_def; + struct { + char* name; + struct ASTNode* value; + } variable_decl; + struct { + struct ASTNode* test; + struct ASTNode** patterns; + int pattern_count; + } when_expr; + struct { + struct ASTNode* test; + struct ASTNode* result; + } when_pattern; + struct { + struct ASTNode** elements; + int element_count; + } table; + struct { + struct ASTNode* object; + struct ASTNode* key; + } table_access; + struct { + char* operation; + struct ASTNode* argument; + } io_operation; + struct { + struct ASTNode** statements; + int statement_count; + } sequence; + } data; +}; + +/* ============================================================================ + * Parser Structure + * ============================================================================ */ + +typedef struct { + Token** tokens; + int token_count; + int current; + bool has_error; + char* error_message; +} Parser; + +/* ============================================================================ + * AST Node Management + * ============================================================================ */ + +/** + * @brief Create a literal node + * + * @param value Literal value + * @param line Line number + * @param column Column number + * @return New literal node + */ +static ASTNode* ast_literal_node(Value value, int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_LITERAL; + node->line = line; + node->column = column; + node->data.literal = value; + + return node; +} + +/** + * @brief Create an identifier node + * + * @param identifier Identifier name + * @param line Line number + * @param column Column number + * @return New identifier node + */ +static ASTNode* ast_identifier_node(const char* identifier, int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_IDENTIFIER; + node->line = line; + node->column = column; + node->data.identifier = strdup(identifier); + + return node; +} + +/** + * @brief Create a function call node + * + * @param function Function expression + * @param arguments Array of argument expressions + * @param arg_count Number of arguments + * @param line Line number + * @param column Column number + * @return New function call node + */ +static ASTNode* ast_function_call_node(ASTNode* function, ASTNode** arguments, + int arg_count, int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_FUNCTION_CALL; + node->line = line; + node->column = column; + node->data.function_call.function = function; + node->data.function_call.arguments = arguments; + node->data.function_call.arg_count = arg_count; + + return node; +} + +/** + * @brief Create a binary operator node + * + * @param left Left operand + * @param right Right operand + * @param operator Operator name + * @param line Line number + * @param column Column number + * @return New binary operator node + */ +static ASTNode* ast_binary_op_node(ASTNode* left, ASTNode* right, + const char* operator, int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_BINARY_OP; + node->line = line; + node->column = column; + node->data.binary.left = left; + node->data.binary.right = right; + node->data.binary.operator = strdup(operator); + + return node; +} + +/** + * @brief Create a unary operator node (translated to function call) + * + * @param operand Operand expression + * @param operator Operator name + * @param line Line number + * @param column Column number + * @return New function call node representing the operator + */ +static ASTNode* ast_unary_op_node(ASTNode* operand, const char* operator, + int line, int column) { + /* Create simple function call: operator(operand) */ + ASTNode* operator_node = ast_identifier_node(operator, line, column); + if (operator_node == NULL) { + return NULL; + } + + ASTNode** args = malloc(1 * sizeof(ASTNode*)); + if (args == NULL) { + free(operator_node); + return NULL; + } + args[0] = operand; + + return ast_function_call_node(operator_node, args, 1, line, column); +} + +/** + * @brief Create a sequence node + * + * @param statements Array of statement nodes + * @param statement_count Number of statements + * @param line Line number + * @param column Column number + * @return New sequence node + */ +static ASTNode* ast_sequence_node(ASTNode** statements, int statement_count, + int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_SEQUENCE; + node->line = line; + node->column = column; + node->data.sequence.statements = statements; + node->data.sequence.statement_count = statement_count; + + return node; +} + +/** + * @brief Create a when expression node + * + * @param test Test expression + * @param patterns Array of pattern nodes + * @param pattern_count Number of patterns + * @param line Line number + * @param column Column number + * @return New when expression node + */ +static ASTNode* ast_when_expr_node(ASTNode* test, ASTNode** patterns, + int pattern_count, int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_WHEN_EXPR; + node->line = line; + node->column = column; + node->data.when_expr.test = test; + node->data.when_expr.patterns = patterns; + node->data.when_expr.pattern_count = pattern_count; + + return node; +} + +/** + * @brief Create a when pattern node + * + * @param test Pattern test expression + * @param result Result expression + * @param line Line number + * @param column Column number + * @return New when pattern node + */ +static ASTNode* ast_when_pattern_node(ASTNode* test, ASTNode* result, + int line, int column) { + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + return NULL; + } + + node->type = NODE_WHEN_PATTERN; + node->line = line; + node->column = column; + node->data.when_pattern.test = test; + node->data.when_pattern.result = result; + + return node; +} + +/** + * @brief Destroy an AST node + * + * @param node Node to destroy + */ +static void ast_destroy_node(ASTNode* node) { + if (node == NULL) { + return; + } + + switch (node->type) { + case NODE_IDENTIFIER: + free(node->data.identifier); + break; + case NODE_FUNCTION_CALL: + for (int i = 0; i < node->data.function_call.arg_count; i++) { + ast_destroy_node(node->data.function_call.arguments[i]); + } + free(node->data.function_call.arguments); + ast_destroy_node(node->data.function_call.function); + break; + case NODE_FUNCTION_DEF: + for (int i = 0; i < node->data.function_def.param_count; i++) { + ast_destroy_node(node->data.function_def.parameters[i]); + } + free(node->data.function_def.parameters); + free(node->data.function_def.name); + ast_destroy_node(node->data.function_def.body); + break; + case NODE_VARIABLE_DECL: + free(node->data.variable_decl.name); + ast_destroy_node(node->data.variable_decl.value); + break; + case NODE_WHEN_EXPR: + ast_destroy_node(node->data.when_expr.test); + for (int i = 0; i < node->data.when_expr.pattern_count; i++) { + ast_destroy_node(node->data.when_expr.patterns[i]); + } + free(node->data.when_expr.patterns); + break; + case NODE_WHEN_PATTERN: + ast_destroy_node(node->data.when_pattern.test); + ast_destroy_node(node->data.when_pattern.result); + break; + case NODE_TABLE: + for (int i = 0; i < node->data.table.element_count; i++) { + ast_destroy_node(node->data.table.elements[i]); + } + free(node->data.table.elements); + break; + case NODE_TABLE_ACCESS: + ast_destroy_node(node->data.table_access.object); + ast_destroy_node(node->data.table_access.key); + break; + case NODE_IO_OPERATION: + free(node->data.io_operation.operation); + ast_destroy_node(node->data.io_operation.argument); + break; + case NODE_SEQUENCE: + for (int i = 0; i < node->data.sequence.statement_count; i++) { + ast_destroy_node(node->data.sequence.statements[i]); + } + free(node->data.sequence.statements); + break; + default: + /* No cleanup needed for other types */ + break; + } + + free(node); +} + +/* ============================================================================ + * Parser Functions + * ============================================================================ */ + +/** + * @brief Create a new parser + * + * @param tokens Array of tokens + * @param token_count Number of tokens + * @return New parser instance, or NULL on failure + */ +static Parser* parser_create(Token** tokens, int token_count) { + Parser* parser = malloc(sizeof(Parser)); + if (parser == NULL) { + return NULL; + } + + parser->tokens = tokens; + parser->token_count = token_count; + parser->current = 0; + parser->has_error = false; + parser->error_message = NULL; + + return parser; +} + +/** + * @brief Destroy a parser + * + * @param parser Parser to destroy + */ +static void parser_destroy(Parser* parser) { + if (parser == NULL) { + return; + } + + if (parser->error_message != NULL) { + free(parser->error_message); + } + + free(parser); +} + +/** + * @brief Set parser error + * + * @param parser Parser instance + * @param message Error message + */ +static void parser_set_error(Parser* parser, const char* message) { + if (parser == NULL) { + return; + } + + parser->has_error = true; + if (parser->error_message != NULL) { + free(parser->error_message); + } + parser->error_message = strdup(message); +} + +/** + * @brief Check if we're at the end of tokens + * + * @param parser Parser instance + * @return true if at end, false otherwise + */ +static bool parser_is_at_end(const Parser* parser) { + return parser->current >= parser->token_count; +} + +/** + * @brief Peek at current token + * + * @param parser Parser instance + * @return Current token, or NULL if at end + */ +static Token* parser_peek(const Parser* parser) { + if (parser_is_at_end(parser)) { + return NULL; + } + return parser->tokens[parser->current]; +} + +/** + * @brief Peek at next token + * + * @param parser Parser instance + * @return Next token, or NULL if at end + */ +static Token* parser_peek_next(const Parser* parser) { + if (parser->current + 1 >= parser->token_count) { + return NULL; + } + return parser->tokens[parser->current + 1]; +} + +/** + * @brief Advance to next token + * + * @param parser Parser instance + * @return Token that was advanced over + */ +static Token* parser_advance(Parser* parser) { + if (parser_is_at_end(parser)) { + return NULL; + } + return parser->tokens[parser->current++]; +} + +/** + * @brief Check if current token matches expected type + * + * @param parser Parser instance + * @param type Expected token type + * @return true if matches, false otherwise + */ +static bool parser_check(const Parser* parser, TokenType type) { + if (parser_is_at_end(parser)) { + return false; + } + return parser->tokens[parser->current]->type == type; +} + +/** + * @brief Consume token of expected type + * + * @param parser Parser instance + * @param type Expected token type + * @param error_message Error message if type doesn't match + * @return Consumed token, or NULL on error + */ +static Token* parser_consume(Parser* parser, TokenType type, const char* error_message) { + if (parser_check(parser, type)) { + return parser_advance(parser); + } + + parser_set_error(parser, error_message); + return NULL; +} + +/* ============================================================================ + * Expression Parsing (Operator Precedence) + * ============================================================================ */ + +/* Forward declarations */ +static ASTNode* parser_parse_expression(Parser* parser); +static ASTNode* parser_parse_logical(Parser* parser); +/* static ASTNode* parser_parse_composition(Parser* parser); */ +/* static ASTNode* parser_parse_application(Parser* parser); */ +static ASTNode* parser_parse_statement(Parser* parser); +static ASTNode* parser_parse_when_expression(Parser* parser); +static ASTNode* parser_parse_when_pattern(Parser* parser); +static const char* node_type_name(NodeType type); + +/** + * @brief Parse primary expression (literals, identifiers, parentheses) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_primary(Parser* parser) { + Token* token = parser_peek(parser); + if (token == NULL) { + parser_set_error(parser, "Unexpected end of input"); + return NULL; + } + + switch (token->type) { + case TOKEN_NUMBER: { + parser_advance(parser); + return ast_literal_node(baba_yaga_value_number(token->literal.number), + token->line, token->column); + } + case TOKEN_STRING: { + parser_advance(parser); + return ast_literal_node(baba_yaga_value_string(token->lexeme), + token->line, token->column); + } + case TOKEN_BOOLEAN: { + parser_advance(parser); + return ast_literal_node(baba_yaga_value_boolean(token->literal.boolean), + token->line, token->column); + } + case TOKEN_IDENTIFIER: { + parser_advance(parser); + /* Special handling for wildcard pattern */ + if (strcmp(token->lexeme, "_") == 0) { + /* Create a special wildcard literal */ + return ast_literal_node(baba_yaga_value_string("_"), token->line, token->column); + } + return ast_identifier_node(token->lexeme, token->line, token->column); + } + case TOKEN_IO_IN: + case TOKEN_IO_OUT: + case TOKEN_IO_ASSERT: { + parser_advance(parser); + /* IO operations are treated as function calls - strip the ".." prefix */ + const char* func_name = token->lexeme + 2; /* Skip ".." */ + + /* For ..assert, parse the entire expression as a single argument */ + if (strcmp(func_name, "assert") == 0) { + /* Parse the assertion expression */ + ASTNode* assertion_expr = parser_parse_expression(parser); + if (assertion_expr == NULL) { + return NULL; + } + + /* Create function call with the assertion expression as argument */ + ASTNode** args = malloc(1 * sizeof(ASTNode*)); + if (args == NULL) { + ast_destroy_node(assertion_expr); + return NULL; + } + args[0] = assertion_expr; + + ASTNode* func_node = ast_identifier_node(func_name, token->line, token->column); + if (func_node == NULL) { + free(args); + ast_destroy_node(assertion_expr); + return NULL; + } + + return ast_function_call_node(func_node, args, 1, token->line, token->column); + } + + return ast_identifier_node(func_name, token->line, token->column); + } + case TOKEN_KEYWORD_WHEN: { + return parser_parse_when_expression(parser); + } + case TOKEN_FUNCTION_REF: { + parser_advance(parser); + + /* Check if this is @(expression) syntax */ + if (!parser_is_at_end(parser) && parser_peek(parser)->type == TOKEN_LPAREN) { + parser_advance(parser); /* consume '(' */ + + /* Parse the expression inside parentheses */ + ASTNode* expr = parser_parse_expression(parser); + if (expr == NULL) { + return NULL; + } + + /* Expect closing parenthesis */ + if (!parser_consume(parser, TOKEN_RPAREN, "Expected ')' after expression")) { + ast_destroy_node(expr); + return NULL; + } + + /* Return the expression as-is (it will be evaluated when used as an argument) */ + return expr; + } + + /* Handle @function_name syntax */ + ASTNode* func_node = ast_identifier_node(token->lexeme, token->line, token->column); + if (func_node == NULL) { + return NULL; + } + + /* Check if this function reference is followed by arguments */ + if (!parser_is_at_end(parser)) { + Token* next_token = parser_peek(parser); + if (next_token != NULL && + next_token->type != TOKEN_OP_PLUS && + next_token->type != TOKEN_OP_MINUS && + next_token->type != TOKEN_OP_MULTIPLY && + next_token->type != TOKEN_OP_DIVIDE && + next_token->type != TOKEN_OP_MODULO && + next_token->type != TOKEN_OP_POWER && + next_token->type != TOKEN_OP_EQUALS && + next_token->type != TOKEN_OP_NOT_EQUALS && + next_token->type != TOKEN_OP_LESS && + next_token->type != TOKEN_OP_LESS_EQUAL && + next_token->type != TOKEN_OP_GREATER && + next_token->type != TOKEN_OP_GREATER_EQUAL && + next_token->type != TOKEN_RPAREN && + next_token->type != TOKEN_RBRACE && + next_token->type != TOKEN_RBRACKET && + next_token->type != TOKEN_SEMICOLON && + next_token->type != TOKEN_COMMA && + next_token->type != TOKEN_EOF) { + + /* Parse arguments for this function call */ + ASTNode** args = NULL; + int arg_count = 0; + + while (!parser_is_at_end(parser)) { + Token* arg_token = parser_peek(parser); + if (arg_token == NULL) { + break; + } + + /* Stop if we hit an operator or delimiter */ + if (arg_token->type == TOKEN_OP_PLUS || + arg_token->type == TOKEN_OP_MINUS || + arg_token->type == TOKEN_OP_MULTIPLY || + arg_token->type == TOKEN_OP_DIVIDE || + arg_token->type == TOKEN_OP_MODULO || + arg_token->type == TOKEN_OP_POWER || + arg_token->type == TOKEN_OP_EQUALS || + arg_token->type == TOKEN_OP_NOT_EQUALS || + arg_token->type == TOKEN_OP_LESS || + arg_token->type == TOKEN_OP_LESS_EQUAL || + arg_token->type == TOKEN_OP_GREATER || + arg_token->type == TOKEN_OP_GREATER_EQUAL || + arg_token->type == TOKEN_RPAREN || + arg_token->type == TOKEN_RBRACE || + arg_token->type == TOKEN_RBRACKET || + arg_token->type == TOKEN_SEMICOLON || + arg_token->type == TOKEN_COMMA || + arg_token->type == TOKEN_EOF) { + break; + } + + /* Parse argument */ + ASTNode* arg = parser_parse_primary(parser); + if (arg == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(func_node); + return NULL; + } + + /* Add to arguments array */ + ASTNode** new_args = realloc(args, (arg_count + 1) * sizeof(ASTNode*)); + if (new_args == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(arg); + ast_destroy_node(func_node); + return NULL; + } + args = new_args; + args[arg_count] = arg; + arg_count++; + } + + /* Create function call with the arguments */ + if (arg_count > 0) { + ASTNode* func_call = ast_function_call_node(func_node, args, arg_count, func_node->line, func_node->column); + if (func_call == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(func_node); + return NULL; + } + return func_call; + } + } + } + + return func_node; + } + case TOKEN_LPAREN: { + parser_advance(parser); /* consume '(' */ + ASTNode* expr = parser_parse_expression(parser); + if (expr == NULL) { + return NULL; + } + + if (!parser_consume(parser, TOKEN_RPAREN, "Expected ')' after expression")) { + ast_destroy_node(expr); + return NULL; + } + + return expr; + } + case TOKEN_OP_UNARY_MINUS: { + parser_advance(parser); /* consume '-' */ + ASTNode* operand = parser_parse_primary(parser); + if (operand == NULL) { + return NULL; + } + return ast_unary_op_node(operand, "negate", token->line, token->column); + } + case TOKEN_KEYWORD_NOT: { + parser_advance(parser); /* consume 'not' */ + ASTNode* operand = parser_parse_primary(parser); + if (operand == NULL) { + return NULL; + } + return ast_unary_op_node(operand, "not", token->line, token->column); + } + default: + parser_set_error(parser, "Unexpected token in expression"); + return NULL; + } +} + +/** + * @brief Parse function call expression + * + * @param parser Parser instance + * @return Parsed expression node + */ +/* TODO: Re-implement function call parsing at application level */ +/* TODO: Re-implement function call parsing at application level */ + +/** + * @brief Parse power expression (^) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_power(Parser* parser) { + ASTNode* left = parser_parse_primary(parser); + if (left == NULL) { + return NULL; + } + + while (parser_check(parser, TOKEN_OP_POWER)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_primary(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + ASTNode* new_left = ast_binary_op_node(left, right, "pow", op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse multiplicative expression (*, /, %) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_multiplicative(Parser* parser) { + ASTNode* left = parser_parse_power(parser); + if (left == NULL) { + return NULL; + } + + while (parser_check(parser, TOKEN_OP_MULTIPLY) || + parser_check(parser, TOKEN_OP_DIVIDE) || + parser_check(parser, TOKEN_OP_MODULO)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_power(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + const char* operator_name; + switch (op->type) { + case TOKEN_OP_MULTIPLY: operator_name = "multiply"; break; + case TOKEN_OP_DIVIDE: operator_name = "divide"; break; + case TOKEN_OP_MODULO: operator_name = "modulo"; break; + default: operator_name = "unknown"; break; + } + + ASTNode* new_left = ast_binary_op_node(left, right, operator_name, op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse additive expression (+, -) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_additive(Parser* parser) { + ASTNode* left = parser_parse_multiplicative(parser); + if (left == NULL) { + return NULL; + } + + while (parser_check(parser, TOKEN_OP_PLUS) || parser_check(parser, TOKEN_OP_MINUS)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_multiplicative(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + const char* operator_name = (op->type == TOKEN_OP_PLUS) ? "add" : "subtract"; + + ASTNode* new_left = ast_binary_op_node(left, right, operator_name, op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse comparison expression (=, !=, <, <=, >, >=) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_comparison(Parser* parser) { + ASTNode* left = parser_parse_additive(parser); + if (left == NULL) { + return NULL; + } + + while (parser_check(parser, TOKEN_OP_EQUALS) || + parser_check(parser, TOKEN_OP_NOT_EQUALS) || + parser_check(parser, TOKEN_OP_LESS) || + parser_check(parser, TOKEN_OP_LESS_EQUAL) || + parser_check(parser, TOKEN_OP_GREATER) || + parser_check(parser, TOKEN_OP_GREATER_EQUAL)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_additive(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + const char* operator_name; + switch (op->type) { + case TOKEN_OP_EQUALS: operator_name = "equals"; break; + case TOKEN_OP_NOT_EQUALS: operator_name = "not_equals"; break; + case TOKEN_OP_LESS: operator_name = "less"; break; + case TOKEN_OP_LESS_EQUAL: operator_name = "less_equal"; break; + case TOKEN_OP_GREATER: operator_name = "greater"; break; + case TOKEN_OP_GREATER_EQUAL: operator_name = "greater_equal"; break; + default: operator_name = "unknown"; break; + } + + ASTNode* new_left = ast_binary_op_node(left, right, operator_name, op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse logical expression (and, or, xor) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_logical(Parser* parser) { + ASTNode* left = parser_parse_comparison(parser); + if (left == NULL) { + return NULL; + } + + /* Handle logical operators */ + while (parser_check(parser, TOKEN_KEYWORD_AND) || + parser_check(parser, TOKEN_KEYWORD_OR) || + parser_check(parser, TOKEN_KEYWORD_XOR)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_comparison(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + const char* operator_name; + switch (op->type) { + case TOKEN_KEYWORD_AND: operator_name = "and"; break; + case TOKEN_KEYWORD_OR: operator_name = "or"; break; + case TOKEN_KEYWORD_XOR: operator_name = "xor"; break; + default: operator_name = "unknown"; break; + } + + ASTNode* new_left = ast_binary_op_node(left, right, operator_name, op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + /* Handle function application */ + while (!parser_is_at_end(parser) && + (parser_peek(parser)->type == TOKEN_IDENTIFIER || + parser_peek(parser)->type == TOKEN_FUNCTION_REF || + parser_peek(parser)->type == TOKEN_NUMBER || + parser_peek(parser)->type == TOKEN_STRING || + parser_peek(parser)->type == TOKEN_LPAREN || + parser_peek(parser)->type == TOKEN_LBRACE || + parser_peek(parser)->type == TOKEN_OP_UNARY_MINUS || + parser_peek(parser)->type == TOKEN_KEYWORD_NOT) && + parser_peek(parser)->type != TOKEN_OP_PLUS && + parser_peek(parser)->type != TOKEN_OP_MINUS && + parser_peek(parser)->type != TOKEN_OP_MULTIPLY && + parser_peek(parser)->type != TOKEN_OP_DIVIDE && + parser_peek(parser)->type != TOKEN_OP_MODULO && + parser_peek(parser)->type != TOKEN_OP_POWER && + parser_peek(parser)->type != TOKEN_OP_EQUALS && + parser_peek(parser)->type != TOKEN_OP_NOT_EQUALS && + parser_peek(parser)->type != TOKEN_OP_LESS && + parser_peek(parser)->type != TOKEN_OP_LESS_EQUAL && + parser_peek(parser)->type != TOKEN_OP_GREATER && + parser_peek(parser)->type != TOKEN_OP_GREATER_EQUAL && + parser_peek(parser)->type != TOKEN_KEYWORD_AND && + parser_peek(parser)->type != TOKEN_KEYWORD_OR && + parser_peek(parser)->type != TOKEN_KEYWORD_XOR && + parser_peek(parser)->type != TOKEN_KEYWORD_WHEN && + parser_peek(parser)->type != TOKEN_KEYWORD_IS && + parser_peek(parser)->type != TOKEN_KEYWORD_THEN && + parser_peek(parser)->type != TOKEN_RPAREN && + parser_peek(parser)->type != TOKEN_RBRACE && + parser_peek(parser)->type != TOKEN_RBRACKET && + parser_peek(parser)->type != TOKEN_SEMICOLON && + parser_peek(parser)->type != TOKEN_COMMA && + parser_peek(parser)->type != TOKEN_EOF) { + + /* Collect all arguments for this function call */ + ASTNode** args = NULL; + int arg_count = 0; + + while (!parser_is_at_end(parser) && + (parser_peek(parser)->type == TOKEN_IDENTIFIER || + parser_peek(parser)->type == TOKEN_FUNCTION_REF || + parser_peek(parser)->type == TOKEN_NUMBER || + parser_peek(parser)->type == TOKEN_STRING || + parser_peek(parser)->type == TOKEN_LPAREN || + parser_peek(parser)->type == TOKEN_LBRACE || + parser_peek(parser)->type == TOKEN_OP_UNARY_MINUS || + parser_peek(parser)->type == TOKEN_KEYWORD_NOT) && + parser_peek(parser)->type != TOKEN_OP_PLUS && + parser_peek(parser)->type != TOKEN_OP_MINUS && + parser_peek(parser)->type != TOKEN_OP_MULTIPLY && + parser_peek(parser)->type != TOKEN_OP_DIVIDE && + parser_peek(parser)->type != TOKEN_OP_MODULO && + parser_peek(parser)->type != TOKEN_OP_POWER && + parser_peek(parser)->type != TOKEN_OP_EQUALS && + parser_peek(parser)->type != TOKEN_OP_NOT_EQUALS && + parser_peek(parser)->type != TOKEN_OP_LESS && + parser_peek(parser)->type != TOKEN_OP_LESS_EQUAL && + parser_peek(parser)->type != TOKEN_OP_GREATER && + parser_peek(parser)->type != TOKEN_OP_GREATER_EQUAL && + parser_peek(parser)->type != TOKEN_KEYWORD_AND && + parser_peek(parser)->type != TOKEN_KEYWORD_OR && + parser_peek(parser)->type != TOKEN_KEYWORD_XOR && + parser_peek(parser)->type != TOKEN_KEYWORD_WHEN && + parser_peek(parser)->type != TOKEN_KEYWORD_IS && + parser_peek(parser)->type != TOKEN_KEYWORD_THEN && + parser_peek(parser)->type != TOKEN_RPAREN && + parser_peek(parser)->type != TOKEN_RBRACE && + parser_peek(parser)->type != TOKEN_RBRACKET && + parser_peek(parser)->type != TOKEN_SEMICOLON && + parser_peek(parser)->type != TOKEN_COMMA && + parser_peek(parser)->type != TOKEN_EOF) { + + ASTNode* arg = parser_parse_comparison(parser); + if (arg == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(left); + return NULL; + } + + /* Add to arguments array */ + ASTNode** new_args = realloc(args, (arg_count + 1) * sizeof(ASTNode*)); + if (new_args == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(arg); + ast_destroy_node(left); + return NULL; + } + args = new_args; + args[arg_count++] = arg; + } + + /* Create function call with all arguments */ + ASTNode* new_left = ast_function_call_node(left, args, arg_count, left->line, left->column); + if (new_left == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(left); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse function composition (via) + * + * @param parser Parser instance + * @return Parsed expression node + */ +/* TODO: Re-implement composition parsing */ +/* +static ASTNode* parser_parse_composition(Parser* parser) { + ASTNode* left = parser_parse_application(parser); + if (left == NULL) { + return NULL; + } + + while (parser_check(parser, TOKEN_KEYWORD_VIA)) { + Token* op = parser_advance(parser); + ASTNode* right = parser_parse_logical(parser); + if (right == NULL) { + ast_destroy_node(left); + return NULL; + } + + ASTNode* new_left = ast_binary_op_node(left, right, "compose", op->line, op->column); + if (new_left == NULL) { + ast_destroy_node(left); + ast_destroy_node(right); + return NULL; + } + + left = new_left; + } + + return left; +} +*/ + +/** + * @brief Parse function application (juxtaposition) + * + * @param parser Parser instance + * @return Parsed expression node + */ +/** + * @brief Parse function application (juxtaposition) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_application(Parser* parser) { + ASTNode* left = parser_parse_logical(parser); + if (left == NULL) { + return NULL; + } + + /* Function application is left-associative */ + while (!parser_is_at_end(parser) && + (parser_peek(parser)->type == TOKEN_IDENTIFIER || + parser_peek(parser)->type == TOKEN_FUNCTION_REF || + parser_peek(parser)->type == TOKEN_NUMBER || + parser_peek(parser)->type == TOKEN_STRING || + parser_peek(parser)->type == TOKEN_LPAREN || + parser_peek(parser)->type == TOKEN_LBRACE || + parser_peek(parser)->type == TOKEN_OP_UNARY_MINUS || + parser_peek(parser)->type == TOKEN_KEYWORD_NOT) && + parser_peek(parser)->type != TOKEN_OP_PLUS && + parser_peek(parser)->type != TOKEN_OP_MINUS && + parser_peek(parser)->type != TOKEN_OP_MULTIPLY && + parser_peek(parser)->type != TOKEN_OP_DIVIDE && + parser_peek(parser)->type != TOKEN_OP_MODULO && + parser_peek(parser)->type != TOKEN_OP_POWER && + parser_peek(parser)->type != TOKEN_OP_EQUALS && + parser_peek(parser)->type != TOKEN_OP_NOT_EQUALS && + parser_peek(parser)->type != TOKEN_OP_LESS && + parser_peek(parser)->type != TOKEN_OP_LESS_EQUAL && + parser_peek(parser)->type != TOKEN_OP_GREATER && + parser_peek(parser)->type != TOKEN_OP_GREATER_EQUAL && + parser_peek(parser)->type != TOKEN_KEYWORD_AND && + parser_peek(parser)->type != TOKEN_KEYWORD_OR && + parser_peek(parser)->type != TOKEN_KEYWORD_XOR && + parser_peek(parser)->type != TOKEN_KEYWORD_WHEN && + parser_peek(parser)->type != TOKEN_KEYWORD_IS && + parser_peek(parser)->type != TOKEN_KEYWORD_THEN && + parser_peek(parser)->type != TOKEN_RPAREN && + parser_peek(parser)->type != TOKEN_RBRACE && + parser_peek(parser)->type != TOKEN_RBRACKET && + parser_peek(parser)->type != TOKEN_SEMICOLON && + parser_peek(parser)->type != TOKEN_COMMA && + parser_peek(parser)->type != TOKEN_EOF) { + + /* Collect all arguments for this function call */ + ASTNode** args = NULL; + int arg_count = 0; + + while (!parser_is_at_end(parser) && + (parser_peek(parser)->type == TOKEN_IDENTIFIER || + parser_peek(parser)->type == TOKEN_FUNCTION_REF || + parser_peek(parser)->type == TOKEN_NUMBER || + parser_peek(parser)->type == TOKEN_STRING || + parser_peek(parser)->type == TOKEN_LPAREN || + parser_peek(parser)->type == TOKEN_LBRACE || + parser_peek(parser)->type == TOKEN_OP_UNARY_MINUS || + parser_peek(parser)->type == TOKEN_KEYWORD_NOT) && + parser_peek(parser)->type != TOKEN_OP_PLUS && + parser_peek(parser)->type != TOKEN_OP_MINUS && + parser_peek(parser)->type != TOKEN_OP_MULTIPLY && + parser_peek(parser)->type != TOKEN_OP_DIVIDE && + parser_peek(parser)->type != TOKEN_OP_MODULO && + parser_peek(parser)->type != TOKEN_OP_POWER && + parser_peek(parser)->type != TOKEN_OP_EQUALS && + parser_peek(parser)->type != TOKEN_OP_NOT_EQUALS && + parser_peek(parser)->type != TOKEN_OP_LESS && + parser_peek(parser)->type != TOKEN_OP_LESS_EQUAL && + parser_peek(parser)->type != TOKEN_OP_GREATER && + parser_peek(parser)->type != TOKEN_OP_GREATER_EQUAL && + parser_peek(parser)->type != TOKEN_KEYWORD_AND && + parser_peek(parser)->type != TOKEN_KEYWORD_OR && + parser_peek(parser)->type != TOKEN_KEYWORD_XOR && + parser_peek(parser)->type != TOKEN_KEYWORD_WHEN && + parser_peek(parser)->type != TOKEN_KEYWORD_IS && + parser_peek(parser)->type != TOKEN_KEYWORD_THEN && + parser_peek(parser)->type != TOKEN_RPAREN && + parser_peek(parser)->type != TOKEN_RBRACE && + parser_peek(parser)->type != TOKEN_RBRACKET && + parser_peek(parser)->type != TOKEN_SEMICOLON && + parser_peek(parser)->type != TOKEN_COMMA && + parser_peek(parser)->type != TOKEN_EOF) { + + ASTNode* arg = parser_parse_logical(parser); + if (arg == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(left); + return NULL; + } + + /* Add to arguments array */ + ASTNode** new_args = realloc(args, (arg_count + 1) * sizeof(ASTNode*)); + if (new_args == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(arg); + ast_destroy_node(left); + return NULL; + } + args = new_args; + args[arg_count++] = arg; + } + + /* Create function call with all arguments */ + ASTNode* new_left = ast_function_call_node(left, args, arg_count, left->line, left->column); + if (new_left == NULL) { + /* Cleanup on error */ + for (int i = 0; i < arg_count; i++) { + ast_destroy_node(args[i]); + } + free(args); + ast_destroy_node(left); + return NULL; + } + + left = new_left; + } + + return left; +} + +/** + * @brief Parse expression (entry point) + * + * @param parser Parser instance + * @return Parsed expression node + */ +static ASTNode* parser_parse_expression(Parser* parser) { + return parser_parse_application(parser); +} + +/* ============================================================================ + * Statement Parsing + * ============================================================================ */ + +/** + * @brief Parse variable declaration + * + * @param parser Parser instance + * @return Parsed variable declaration node + */ +static ASTNode* parser_parse_variable_decl(Parser* parser) { + Token* name = parser_consume(parser, TOKEN_IDENTIFIER, "Expected variable name"); + if (name == NULL) { + return NULL; + } + + if (!parser_consume(parser, TOKEN_COLON, "Expected ':' after variable name")) { + return NULL; + } + + ASTNode* value = parser_parse_expression(parser); + if (value == NULL) { + return NULL; + } + + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + ast_destroy_node(value); + return NULL; + } + + node->type = NODE_VARIABLE_DECL; + node->line = name->line; + node->column = name->column; + node->data.variable_decl.name = strdup(name->lexeme); + node->data.variable_decl.value = value; + + return node; +} + +/** + * @brief Parse function definition + * + * @param parser Parser instance + * @return Parsed function definition node + */ +static ASTNode* parser_parse_function_def(Parser* parser) { + Token* name = parser_consume(parser, TOKEN_IDENTIFIER, "Expected function name"); + if (name == NULL) { + return NULL; + } + + if (!parser_consume(parser, TOKEN_COLON, "Expected ':' after function name")) { + return NULL; + } + + /* Parse parameters */ + ASTNode** parameters = NULL; + int param_count = 0; + + while (!parser_is_at_end(parser) && + parser_peek(parser)->type == TOKEN_IDENTIFIER) { + Token* param = parser_advance(parser); + + ASTNode** new_params = realloc(parameters, (param_count + 1) * sizeof(ASTNode*)); + if (new_params == NULL) { + for (int i = 0; i < param_count; i++) { + ast_destroy_node(parameters[i]); + } + free(parameters); + return NULL; + } + parameters = new_params; + + parameters[param_count] = ast_identifier_node(param->lexeme, param->line, param->column); + param_count++; + } + + if (!parser_consume(parser, TOKEN_ARROW, "Expected '->' after parameters")) { + for (int i = 0; i < param_count; i++) { + ast_destroy_node(parameters[i]); + } + free(parameters); + return NULL; + } + + ASTNode* body = parser_parse_expression(parser); + if (body == NULL) { + for (int i = 0; i < param_count; i++) { + ast_destroy_node(parameters[i]); + } + free(parameters); + return NULL; + } + + ASTNode* node = malloc(sizeof(ASTNode)); + if (node == NULL) { + for (int i = 0; i < param_count; i++) { + ast_destroy_node(parameters[i]); + } + free(parameters); + ast_destroy_node(body); + return NULL; + } + + node->type = NODE_FUNCTION_DEF; + node->line = name->line; + node->column = name->column; + node->data.function_def.name = strdup(name->lexeme); + node->data.function_def.parameters = parameters; + node->data.function_def.param_count = param_count; + node->data.function_def.body = body; + + return node; +} + +/** + * @brief Parse multiple statements separated by semicolons + * + * @param parser Parser instance + * @return Parsed sequence node or single statement node + */ +static ASTNode* parser_parse_statements(Parser* parser) { + if (parser_is_at_end(parser)) { + return NULL; + } + + /* Parse first statement */ + ASTNode* first_statement = parser_parse_statement(parser); + if (first_statement == NULL) { + return NULL; + } + + /* Check if there are more statements (semicolon-separated) */ + if (parser_is_at_end(parser)) { + return first_statement; /* Single statement */ + } + + Token* next_token = parser_peek(parser); + if (next_token->type != TOKEN_SEMICOLON) { + return first_statement; /* Single statement */ + } + + /* We have multiple statements, collect them */ + ASTNode** statements = malloc(10 * sizeof(ASTNode*)); /* Start with space for 10 */ + if (statements == NULL) { + ast_destroy_node(first_statement); + return NULL; + } + + int statement_count = 0; + int capacity = 10; + + /* Add first statement */ + statements[statement_count++] = first_statement; + + /* Parse remaining statements */ + while (!parser_is_at_end(parser) && + parser_peek(parser)->type == TOKEN_SEMICOLON) { + + /* Consume semicolon */ + parser_consume(parser, TOKEN_SEMICOLON, "Expected semicolon"); + + /* Skip any whitespace/comments after semicolon */ + while (!parser_is_at_end(parser) && + (parser_peek(parser)->type == TOKEN_COMMENT)) { + parser->current++; /* Skip comment */ + } + + if (parser_is_at_end(parser)) { + break; /* Trailing semicolon */ + } + + /* Parse next statement */ + ASTNode* next_statement = parser_parse_statement(parser); + if (next_statement == NULL) { + /* Error parsing statement, but continue with what we have */ + break; + } + + /* Expand array if needed */ + if (statement_count >= capacity) { + capacity *= 2; + ASTNode** new_statements = realloc(statements, capacity * sizeof(ASTNode*)); + if (new_statements == NULL) { + /* Cleanup and return what we have */ + for (int i = 0; i < statement_count; i++) { + ast_destroy_node(statements[i]); + } + free(statements); + return NULL; + } + statements = new_statements; + } + + statements[statement_count++] = next_statement; + } + + /* If we only have one statement, return it directly */ + if (statement_count == 1) { + ASTNode* result = statements[0]; + free(statements); + return result; + } + + /* Create sequence node */ + return ast_sequence_node(statements, statement_count, + first_statement->line, first_statement->column); +} + +/** + * @brief Parse statement + * + * @param parser Parser instance + * @return Parsed statement node + */ +static ASTNode* parser_parse_statement(Parser* parser) { + if (parser_is_at_end(parser)) { + return NULL; + } + + Token* token = parser_peek(parser); + + /* Check for variable declaration */ + if (token->type == TOKEN_IDENTIFIER && + parser_peek_next(parser) != NULL && + parser_peek_next(parser)->type == TOKEN_COLON) { + + /* Look ahead to see if it's a function definition */ + int save_current = parser->current; + parser->current += 2; /* skip identifier and colon */ + + bool is_function = false; + while (!parser_is_at_end(parser) && + parser_peek(parser)->type == TOKEN_IDENTIFIER) { + parser->current++; + } + + if (!parser_is_at_end(parser) && + parser_peek(parser)->type == TOKEN_ARROW) { + is_function = true; + } + + parser->current = save_current; + + if (is_function) { + return parser_parse_function_def(parser); + } else { + return parser_parse_variable_decl(parser); + } + } + + + + /* Default to expression */ + return parser_parse_expression(parser); +} + +/* ============================================================================ + * Public Parser API + * ============================================================================ */ + +/** + * @brief Parse source code into AST + * + * @param tokens Array of tokens + * @param token_count Number of tokens + * @return Root AST node, or NULL on error + */ +void* baba_yaga_parse(void** tokens, size_t token_count) { + if (tokens == NULL || token_count == 0) { + return NULL; + } + + Parser* parser = parser_create((Token**)tokens, (int)token_count); + if (parser == NULL) { + return NULL; + } + + ASTNode* result = parser_parse_statements(parser); + + if (parser->has_error) { + fprintf(stderr, "Parse error: %s\n", parser->error_message); + if (result != NULL) { + ast_destroy_node(result); + result = NULL; + } + } + + parser_destroy(parser); + return (void*)result; +} + +/** + * @brief Destroy AST + * + * @param node Root AST node + */ +void baba_yaga_destroy_ast(void* node) { + ast_destroy_node((ASTNode*)node); +} + +/** + * @brief Print AST for debugging + * + * @param node Root AST node + * @param indent Initial indentation level + */ +/* ============================================================================ + * AST Accessor Functions + * ============================================================================ */ + +NodeType baba_yaga_ast_get_type(void* node) { + if (node == NULL) { + return NODE_LITERAL; /* Default fallback */ + } + ASTNode* ast_node = (ASTNode*)node; + return ast_node->type; +} + +Value baba_yaga_ast_get_literal(void* node) { + if (node == NULL) { + return baba_yaga_value_nil(); + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_LITERAL) { + return baba_yaga_value_copy(&ast_node->data.literal); + } + return baba_yaga_value_nil(); +} + +const char* baba_yaga_ast_get_identifier(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_IDENTIFIER) { + return ast_node->data.identifier; + } + return NULL; +} + +void* baba_yaga_ast_get_function_call_func(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_CALL) { + return ast_node->data.function_call.function; + } + return NULL; +} + +int baba_yaga_ast_get_function_call_arg_count(void* node) { + if (node == NULL) { + return 0; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_CALL) { + return ast_node->data.function_call.arg_count; + } + return 0; +} + +void* baba_yaga_ast_get_function_call_arg(void* node, int index) { + if (node == NULL || index < 0) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_CALL && + index < ast_node->data.function_call.arg_count) { + return ast_node->data.function_call.arguments[index]; + } + return NULL; +} + +void* baba_yaga_ast_get_binary_op_left(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_BINARY_OP) { + return ast_node->data.binary.left; + } + return NULL; +} + +void* baba_yaga_ast_get_binary_op_right(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_BINARY_OP) { + return ast_node->data.binary.right; + } + return NULL; +} + +const char* baba_yaga_ast_get_binary_op_operator(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_BINARY_OP) { + return ast_node->data.binary.operator; + } + return NULL; +} + +void* baba_yaga_ast_get_unary_op_operand(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_UNARY_OP) { + return ast_node->data.unary.operand; + } + return NULL; +} + +const char* baba_yaga_ast_get_unary_op_operator(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_UNARY_OP) { + return ast_node->data.unary.operator; + } + return NULL; +} + +const char* baba_yaga_ast_get_function_def_name(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_DEF) { + return ast_node->data.function_def.name; + } + return NULL; +} + +int baba_yaga_ast_get_function_def_param_count(void* node) { + if (node == NULL) { + return 0; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_DEF) { + return ast_node->data.function_def.param_count; + } + return 0; +} + +void* baba_yaga_ast_get_function_def_param(void* node, int index) { + if (node == NULL || index < 0) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_DEF) { + if (index < ast_node->data.function_def.param_count) { + return ast_node->data.function_def.parameters[index]; + } + } + return NULL; +} + +void* baba_yaga_ast_get_function_def_body(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_FUNCTION_DEF) { + return ast_node->data.function_def.body; + } + return NULL; +} + +const char* baba_yaga_ast_get_variable_decl_name(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_VARIABLE_DECL) { + return ast_node->data.variable_decl.name; + } + return NULL; +} + +void* baba_yaga_ast_get_variable_decl_value(void* node) { + if (node == NULL) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_VARIABLE_DECL) { + return ast_node->data.variable_decl.value; + } + return NULL; +} + +int baba_yaga_ast_get_sequence_statement_count(void* node) { + if (node == NULL) { + return 0; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_SEQUENCE) { + return ast_node->data.sequence.statement_count; + } + return 0; +} + +void* baba_yaga_ast_get_sequence_statement(void* node, int index) { + if (node == NULL || index < 0) { + return NULL; + } + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type == NODE_SEQUENCE) { + if (index < ast_node->data.sequence.statement_count) { + return ast_node->data.sequence.statements[index]; + } + } + return NULL; +} + +void* baba_yaga_ast_get_when_expr_test(void* node) { + if (node == NULL) { + return NULL; + } + + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type != NODE_WHEN_EXPR) { + return NULL; + } + + return ast_node->data.when_expr.test; +} + +int baba_yaga_ast_get_when_expr_pattern_count(void* node) { + if (node == NULL) { + return 0; + } + + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type != NODE_WHEN_EXPR) { + return 0; + } + + return ast_node->data.when_expr.pattern_count; +} + +void* baba_yaga_ast_get_when_expr_pattern(void* node, int index) { + if (node == NULL) { + return NULL; + } + + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type != NODE_WHEN_EXPR) { + return NULL; + } + + if (index >= 0 && index < ast_node->data.when_expr.pattern_count) { + return ast_node->data.when_expr.patterns[index]; + } + return NULL; +} + +void* baba_yaga_ast_get_when_pattern_test(void* node) { + if (node == NULL) { + return NULL; + } + + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type != NODE_WHEN_PATTERN) { + return NULL; + } + + return ast_node->data.when_pattern.test; +} + +void* baba_yaga_ast_get_when_pattern_result(void* node) { + if (node == NULL) { + return NULL; + } + + ASTNode* ast_node = (ASTNode*)node; + if (ast_node->type != NODE_WHEN_PATTERN) { + return NULL; + } + + return ast_node->data.when_pattern.result; +} + +void baba_yaga_print_ast(void* node, int indent) { + if (node == NULL) { + return; + } + + ASTNode* ast_node = (ASTNode*)node; + + /* Print indentation */ + for (int i = 0; i < indent; i++) { + printf(" "); + } + + /* Print node type */ + printf("%s", node_type_name(ast_node->type)); + + /* Print node-specific information */ + switch (ast_node->type) { + case NODE_LITERAL: + if (ast_node->data.literal.type == VAL_NUMBER) { + printf(": %g", ast_node->data.literal.data.number); + } else if (ast_node->data.literal.type == VAL_STRING) { + printf(": \"%s\"", ast_node->data.literal.data.string); + } else if (ast_node->data.literal.type == VAL_BOOLEAN) { + printf(": %s", ast_node->data.literal.data.boolean ? "true" : "false"); + } + break; + case NODE_IDENTIFIER: + printf(": %s", ast_node->data.identifier); + break; + case NODE_FUNCTION_CALL: + printf(" (args: %d)", ast_node->data.function_call.arg_count); + break; + case NODE_FUNCTION_DEF: + printf(": %s (params: %d)", ast_node->data.function_def.name, ast_node->data.function_def.param_count); + break; + case NODE_VARIABLE_DECL: + printf(": %s", ast_node->data.variable_decl.name); + break; + case NODE_SEQUENCE: + printf(" (statements: %d)", ast_node->data.sequence.statement_count); + break; + default: + break; + } + + printf(" (line %d, col %d)\n", ast_node->line, ast_node->column); + + /* Print children */ + switch (ast_node->type) { + case NODE_FUNCTION_CALL: + baba_yaga_print_ast(ast_node->data.function_call.function, indent + 1); + for (int i = 0; i < ast_node->data.function_call.arg_count; i++) { + baba_yaga_print_ast(ast_node->data.function_call.arguments[i], indent + 1); + } + break; + case NODE_FUNCTION_DEF: + for (int i = 0; i < ast_node->data.function_def.param_count; i++) { + baba_yaga_print_ast(ast_node->data.function_def.parameters[i], indent + 1); + } + baba_yaga_print_ast(ast_node->data.function_def.body, indent + 1); + break; + case NODE_VARIABLE_DECL: + baba_yaga_print_ast(ast_node->data.variable_decl.value, indent + 1); + break; + case NODE_SEQUENCE: + for (int i = 0; i < ast_node->data.sequence.statement_count; i++) { + baba_yaga_print_ast(ast_node->data.sequence.statements[i], indent + 1); + } + break; + default: + break; + } +} + +/** + * @brief Parse when expression + * + * @param parser Parser instance + * @return Parsed when expression node + */ +static ASTNode* parser_parse_when_expression(Parser* parser) { + /* Consume 'when' keyword */ + Token* when_token = parser_consume(parser, TOKEN_KEYWORD_WHEN, "Expected 'when'"); + if (when_token == NULL) { + return NULL; + } + + /* Parse test expression */ + ASTNode* test = parser_parse_expression(parser); + if (test == NULL) { + return NULL; + } + + /* Consume 'is' keyword */ + Token* is_token = parser_consume(parser, TOKEN_KEYWORD_IS, "Expected 'is' after test expression"); + if (is_token == NULL) { + ast_destroy_node(test); + return NULL; + } + + /* Parse patterns */ + ASTNode** patterns = NULL; + int pattern_count = 0; + int capacity = 5; /* Start with space for 5 patterns */ + + patterns = malloc(capacity * sizeof(ASTNode*)); + if (patterns == NULL) { + ast_destroy_node(test); + return NULL; + } + + /* Parse first pattern */ + ASTNode* pattern = parser_parse_when_pattern(parser); + if (pattern == NULL) { + free(patterns); + ast_destroy_node(test); + return NULL; + } + + patterns[pattern_count++] = pattern; + + /* Parse additional patterns */ + while (!parser_is_at_end(parser)) { + /* Parse next pattern */ + ASTNode* next_pattern = parser_parse_when_pattern(parser); + if (next_pattern == NULL) { + break; /* Error parsing pattern, but continue with what we have */ + } + + /* Expand array if needed */ + if (pattern_count >= capacity) { + capacity *= 2; + ASTNode** new_patterns = realloc(patterns, capacity * sizeof(ASTNode*)); + if (new_patterns == NULL) { + /* Cleanup and return what we have */ + for (int i = 0; i < pattern_count; i++) { + ast_destroy_node(patterns[i]); + } + free(patterns); + ast_destroy_node(test); + return NULL; + } + patterns = new_patterns; + } + + patterns[pattern_count++] = next_pattern; + } + + /* Create when expression node */ + return ast_when_expr_node(test, patterns, pattern_count, + when_token->line, when_token->column); +} + +/** + * @brief Parse when pattern + * + * @param parser Parser instance + * @return Parsed when pattern node + */ +static ASTNode* parser_parse_when_pattern(Parser* parser) { + /* Parse pattern test expression */ + ASTNode* pattern_test = parser_parse_expression(parser); + if (pattern_test == NULL) { + return NULL; + } + + /* Consume 'then' keyword */ + Token* then_token = parser_consume(parser, TOKEN_KEYWORD_THEN, "Expected 'then' after pattern"); + if (then_token == NULL) { + ast_destroy_node(pattern_test); + return NULL; + } + + /* Parse result expression */ + ASTNode* result = parser_parse_expression(parser); + if (result == NULL) { + ast_destroy_node(pattern_test); + return NULL; + } + + /* Create when pattern node */ + return ast_when_pattern_node(pattern_test, result, + then_token->line, then_token->column); +} + +/* Helper function to get node type name */ +static const char* node_type_name(NodeType type) { + switch (type) { + case NODE_LITERAL: return "LITERAL"; + case NODE_IDENTIFIER: return "IDENTIFIER"; + case NODE_BINARY_OP: return "BINARY_OP"; + case NODE_UNARY_OP: return "UNARY_OP"; + case NODE_FUNCTION_CALL: return "FUNCTION_CALL"; + case NODE_FUNCTION_DEF: return "FUNCTION_DEF"; + case NODE_VARIABLE_DECL: return "VARIABLE_DECL"; + case NODE_WHEN_EXPR: return "WHEN_EXPR"; + case NODE_WHEN_PATTERN: return "WHEN_PATTERN"; + case NODE_TABLE: return "TABLE"; + case NODE_TABLE_ACCESS: return "TABLE_ACCESS"; + case NODE_IO_OPERATION: return "IO_OPERATION"; + case NODE_SEQUENCE: return "SEQUENCE"; + default: return "UNKNOWN"; + } +} |