view src/expr.c @ 67:d5fe306f1ab1

Fixed numerous bugs in macro handling
author lost
date Mon, 05 Jan 2009 05:40:33 +0000
parents 73423b66e511
children 2fe5fd7d65a3
line wrap: on
line source

/*
expr.c
Copyright © 2008 William Astle

This file is part of LWASM.

LWASM is free software: you can redistribute it and/or modify it under the
terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
more details.

You should have received a copy of the GNU General Public License along with
this program. If not, see <http://www.gnu.org/licenses/>.
*/

/*
This file contains the actual expression evaluator
*/

#define __expr_c_seen__

#include <ctype.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>

#include "expr.h"
#include "util.h"
#include "lwasm.h"

/*
Allocate a new, empty expression stack.

The stack starts with no nodes; release it (and any terms pushed onto it)
with lwasm_expr_stack_free().
*/
lwasm_expr_stack_t *lwasm_expr_stack_create(void)
{
	lwasm_expr_stack_t *stack = lwasm_alloc(sizeof(lwasm_expr_stack_t));

	stack -> head = stack -> tail = NULL;
	return stack;
}

/*
Release an expression stack together with every node and term it holds.

Nodes are visited from head to tail; each node's term is freed before the
node itself, and the stack structure is freed last.
*/
void lwasm_expr_stack_free(lwasm_expr_stack_t *s)
{
	lwasm_expr_stack_node_t *node = s -> head;

	while (node)
	{
		// remember the successor before the current node goes away
		lwasm_expr_stack_node_t *following = node -> next;

		lwasm_expr_term_free(node -> term);
		lwasm_free(node);
		node = following;
	}
	lwasm_free(s);
}

/*
Release a single expression term. A NULL term is ignored.

Only symbol terms own a separately allocated string (see
lwasm_expr_term_create_sym), so only those have it freed as well.
*/
void lwasm_expr_term_free(lwasm_expr_term_t *t)
{
	if (t == NULL)
		return;

	if (t -> term_type == LWASM_TERM_SYM)
		lwasm_free(t -> symbol);
	lwasm_free(t);
}

/*
Create a new operator term.

oper is one of the LWASM_OPER_* codes and is stored in the term's value
field. The caller owns the result and frees it with lwasm_expr_term_free().
*/
lwasm_expr_term_t *lwasm_expr_term_create_oper(int oper)
{
	lwasm_expr_term_t *term;

	debug_message(10, "Creating operator term: %d", oper);

	term = lwasm_alloc(sizeof *term);
	term -> value = oper;
	term -> term_type = LWASM_TERM_OPER;
	return term;
}

/*
Create a new integer constant term holding val.

The caller owns the result and frees it with lwasm_expr_term_free().
*/
lwasm_expr_term_t *lwasm_expr_term_create_int(int val)
{
	lwasm_expr_term_t *term;

	debug_message(10, "Creating integer term: %d", val);

	term = lwasm_alloc(sizeof *term);
	term -> value = val;
	term -> term_type = LWASM_TERM_INT;
	return term;
}

/*
Create a new symbol reference term.

The symbol name is copied with lwasm_strdup(), so the caller keeps
ownership of sym. The returned term owns its copy; free it with
lwasm_expr_term_free().
*/
lwasm_expr_term_t *lwasm_expr_term_create_sym(char *sym)
{
	lwasm_expr_term_t *term;

	debug_message(10, "Creating symbol term: %s", sym);

	term = lwasm_alloc(sizeof *term);
	term -> symbol = lwasm_strdup(sym);
	term -> term_type = LWASM_TERM_SYM;
	return term;
}

/*
Return a newly allocated deep copy of term t (symbol names are copied too).

An unknown term type is treated as fatal: a debug message is emitted and
the program exits with status 1.
*/
lwasm_expr_term_t *lwasm_expr_term_dup(lwasm_expr_term_t *t)
{
	if (t -> term_type == LWASM_TERM_INT)
		return lwasm_expr_term_create_int(t -> value);

	if (t -> term_type == LWASM_TERM_OPER)
		return lwasm_expr_term_create_oper(t -> value);

	if (t -> term_type == LWASM_TERM_SYM)
		return lwasm_expr_term_create_sym(t -> symbol);

	// any other type means the term is corrupt; there is no way to recover
	debug_message(0, "lwasm_expr_term_dup(): invalid term type %d", t -> term_type);
	exit(1);
}

/*
Append a copy of term t to the tail (top) of stack s.

The stack stores a duplicate made with lwasm_expr_term_dup(), so the caller
keeps ownership of t. A NULL stack is fatal: a debug message is emitted and
the program exits with status 1.
*/
void lwasm_expr_stack_push(lwasm_expr_stack_t *s, lwasm_expr_term_t *t)
{
	lwasm_expr_stack_node_t *node;

	if (s == NULL)
	{
		debug_message(0, "lwasm_expr_stack_push(): invalid stack pointer");
		exit(1);
	}

	node = lwasm_alloc(sizeof(lwasm_expr_stack_node_t));
	node -> prev = s -> tail;
	node -> next = NULL;
	node -> term = lwasm_expr_term_dup(t);

	// link the node in after the current tail, or make it the only node
	if (s -> head == NULL)
		s -> head = node;
	else
		s -> tail -> next = node;
	s -> tail = node;
}

/*
Remove the tail (most recently pushed) node from stack s and return its term.

Ownership of the returned term passes to the caller, who frees it with
lwasm_expr_term_free(). Returns NULL if s is NULL or the stack is empty.
*/
lwasm_expr_term_t *lwasm_expr_stack_pop(lwasm_expr_stack_t *s)
{
	lwasm_expr_term_t *t;
	lwasm_expr_stack_node_t *n;

	if (!s || !(s -> tail))
		return NULL;

	n = s -> tail;
	s -> tail = n -> prev;
	if (s -> tail)
	{
		// detach the new tail from the node being freed; otherwise a walk
		// from head (e.g. lwasm_expr_stack_free) would reach freed memory
		s -> tail -> next = NULL;
	}
	else
	{
		// popped the only node: the stack is now empty
		s -> head = NULL;
	}

	t = n -> term;
	lwasm_free(n);

	return t;
}

// the following two functions are mutually recursive: together they parse
// an infix expression onto the expression stack in postfix order; each
// returns -1 if an error is encountered

/*
parse a term and push it onto the stack

this function handles unary prefix operators (-, +, .not., .com.)
as well as ()
*/
int lwasm_expr_parse_term(lwasm_expr_stack_t *s, const char **p)
{
	lwasm_expr_term_t *t;
	debug_message(2, "Expression string %s", *p);

eval_next:
	if (!**p || isspace(**p) || **p == ')' || **p == ']')
		return -1;
	if (**p == '(')
	{
		debug_message(3, "Starting paren");
		(*p)++;
		lwasm_expr_parse_expr(s, p, 0);
		if (**p != ')')
			return -1;
		(*p)++;
		return 0;
	}
	
	if (**p == '+')
	{
		debug_message(3, "Unary +");
		(*p)++;
		goto eval_next;
	}
	
	if (**p == '-')
	{
		// parse expression following "-"
		(*p)++;
		if (lwasm_expr_parse_expr(s, p, 200) < 0)
			return -1;
		t = lwasm_expr_term_create_oper(LWASM_OPER_NEG);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	
	if (**p == '^')
	{
		// parse expression following "^"
		(*p)++;
		if (lwasm_expr_parse_expr(s, p, 200) < 0)
			return -1;
		t = lwasm_expr_term_create_oper(LWASM_OPER_COM);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	
	/*
		we have an actual term here so evaluate it
		
		it could be one of the following:
		
		1. a decimal constant
		2. a hexadecimal constant
		3. an octal constant
		4. a binary constant
		5. a symbol reference
		6. the "current" instruction address (*)
		7. the "current" data address (.)
		8. a "back reference" (<)
		9. a "forward reference" (>)
		
		items 6 through 9 are stored as symbol references
		
		(a . followed by a . or a alpha char or number is a symbol)
	*/
	if (**p == '*'
		|| (
			**p == '.' 
			&& (*p)[1] != '.' 
			&& !((*p)[1] >= 'A' && (*p)[1] <= 'Z') 
			&& !((*p)[1] >= 'a' && (*p)[1] <= 'z') 
			&& !((*p)[1] >= '0' && (*p)[1] <= '9')
			)
		|| **p == '<'
		|| **p == '>')
	{
		char tstr[2];
		tstr[0] = **p;
		tstr[1] = '\0';
		t = lwasm_expr_term_create_sym(tstr);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		(*p)++;
		return 0;
	}
	
	/*
		- a symbol will be a string of characters introduced by a letter, ".",
		  "_" but NOT a number
		- a decimal constant will consist of only digits, optionally prefixed
		  with "&"
		- a binary constant will consist of only 0s and 1s either prefixed with %
		  or suffixed with "B"
		- a hex constant will consist of 0-9A-F either prefixed with $ or
		  suffixed with "H"; a hex number starting with A-F must be prefixed
		  with $ or start with 0 and end with H
		- an octal constant will consist of 0-7 either prefixed with @ or
		  suffixed with "O" or "Q"
		- an ascii constant will be a single character prefixed with a '
		- a double ascii constant will be two characters prefixed with a "
		
	*/
	if (**p == '"')
	{
		// double ascii constant
		int val;
		(*p)++;
		if (!**p)
			return -1;
		if (!*((*p)+1))
			return -1;
		val = **p << 8 | *((*p) + 1);
		(*p) += 2;
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	else if (**p == '\'')
	{
		// single ascii constant
		int val;
		(*p)++;
		debug_message(3, "Single ascii character constant '%c'", **p);
		if (!**p)
			return -1;
		val = **p;
		(*p)++;
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	else if (**p == '&')
	{
		// decimal constant
		int val = 0;
		
		(*p)++;
		while (**p && strchr("0123456789", **p))
		{
			val = val * 10 + (**p - '0');
			(*p)++;
		}
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	else if (**p == '%')
	{
		// binary constant
		int val = 0;
		
		(*p)++;
		while (**p == '0' || **p == '1')
		{
			val = val * 2 + (**p - '0');
			(*p)++;
		}
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	else if (**p == '$')
	{
		// hexadecimal constant
		int val = 0, val2;
		
		(*p)++;
		debug_message(3, "Found prefix hex constant: %s", *p);
		while (**p && strchr("0123456789ABCDEFabcdef", **p))
		{
			val2 = toupper(**p) - '0';
			if (val2 > 9)
				val2 -= 7;
			debug_message(3, "Got char: %c (%d)", **p, val2);
			val = val * 16 + val2;
			(*p)++;
		}
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	// an @ followed by a digit is an octal number
	// but if it's followed by anything else, it is a symbol
	else if (**p == '@' && isdigit(*(*p + 1)))
	{
		// octal constant
		int val = 0;
		
		(*p)++;
		while (**p && strchr("01234567", **p))
		{
			val = val * 8 + (**p - '0');
			(*p)++;
		}
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	
	// symbol or bare decimal or suffix identified constant here
	// all numbers will start with a digit at this point
	if (**p < '0' || **p > '9')
	{
		int l = 0;
		char *sb;
		
		// evaluate a symbol here
		static const char *symchars = "_.$@?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
		while ((*p)[l] && strchr(symchars, (*p)[l]))
			l++;

		if (l == 0)
			return -1;

		sb = lwasm_alloc(l + 1);
		sb[l] = '\0';
		memcpy(sb, *p, l);
		t = lwasm_expr_term_create_sym(sb);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		(*p) += l;
		debug_message(3, "Symbol: '%s'; (%s)", sb, *p);
		lwasm_free(sb);
		return 0;
	}
	
	if (!**p)
		return -1;
	
	// evaluate a suffix based constant
	{
		int decval = 0, binval = 0, hexval = 0, octval = 0;
		int valtype = 15;	// 1 = bin, 2 = oct, 4 = dec, 8 = hex
		int bindone = 0;
		int val;
		int dval;
		
		while (1)
		{
			if (!**p || !strchr("0123456789ABCDEFabcdefqhoQHO", **p))
			{
				// we can legally have bin or decimal here
				if (bindone)
				{
					// we just finished a binary value
					val = binval;
					break;
				}
				else if (valtype & 4)
				{
					// otherwise we must be decimal (if we're still allowed one)
					val = decval;
					debug_message(3, "End of decimal value");
					break;
				}
				else
				{
					// bad value
					return -1;
				}
			}
			
			dval = toupper(**p);
			(*p)++;
			
			if (bindone)
			{
				// any characters past "B" means it is not binary
				bindone = 0;
				valtype &= 14;
			}
			
			switch (dval)
			{
			case 'Q':
			case 'O':
				if (valtype & 2)
				{
					val = octval;
					valtype = -1;
					break;
				}
				else
				{
					// not a valid octal value
					return -1;
				}
				/* can't get here */

			case 'H':
				if (valtype & 8)
				{
					val = hexval;
					valtype = -1;
					break;
				}
				else
				{
					// not a valid hex number
					return -1;
				}
				/* can't get here */

			case 'B':
				// this is a bit of a sticky one since B is a legit hex
				// digit so this may or may not be the end of the number
				// so we fall through to the digit case

				if (valtype & 1)
				{
					// could still be binary
					bindone = 1;
					valtype = 9;	// hex and binary
				}
				/* fall through intentional */
				
			default:
				// digit
				dval -= '0';
				if (dval > 9)
					dval -= 7;
				debug_message(3, "Got digit: %d", dval);
//				if (dval > 1)
//					valtype &= 14;
//				if (dval > 7)
//					valtype &= 12;
//				if (dval > 9)
//					valtype &= 8;

				if (valtype & 8)
				{
					hexval = hexval * 16 + dval;
				}
				if (valtype & 4)
				{
					if (dval > 9)
						valtype &= 11;
					else
						decval = decval * 10 + dval;
				}
				if (valtype & 2)
				{
					if (dval > 7)
						valtype &= 13;
					else
						octval = octval * 8 + dval;
				}
				if (valtype & 1)
				{
					if (dval > 1)
						valtype &= 14;
					else
						binval = binval * 2 + dval;
				}
			}
			// break out if we have a return value
			if (valtype == -1)
				break;
			// return if no more valid possibilities!
			if (valtype == 0)
				return -1;
			val = decval;	// in case we fall through
		}
		
		// we get here when we have a value to return
		t = lwasm_expr_term_create_int(val);
		lwasm_expr_stack_push(s, t);
		lwasm_expr_term_free(t);
		return 0;
	}
	/* can't get here */
}

/*
Return nonzero if c ends an expression or subexpression: end of string,
whitespace, a closing ")" or "]", or a "," separator.
*/
static int lwasm_expr_is_end(char c)
{
	return c == '\0' || isspace((unsigned char)c) || c == ')' || c == ',' || c == ']';
}

// parse an expression and push it onto the stack s in postfix order,
// advancing *p past it
//
// prec is the precedence of the operator to the left of this subexpression
// (0 at the top level); parsing stops, without consuming anything, at the
// first binary operator whose precedence is less than or equal to prec so
// that the enclosing call can handle it
//
// returns 0 on success (an empty expression pushes nothing and succeeds) or
// -1 on an unrecognized operator or a term parse error
int lwasm_expr_parse_expr(lwasm_expr_stack_t *s, const char **p, int prec)
{
	static const struct operinfo
	{
		int opernum;
		const char *operstr;
		int operprec;
	} operators[] =
	{
		{ LWASM_OPER_PLUS, "+", 100 },
		{ LWASM_OPER_MINUS, "-", 100 },
		{ LWASM_OPER_TIMES, "*", 150 },
		{ LWASM_OPER_DIVIDE, "/", 150 },
		{ LWASM_OPER_MOD, "%", 150 },
		{ LWASM_OPER_INTDIV, "\\", 150 },

		{ LWASM_OPER_NONE, "", 0 }
	};
	int opern, i;
	lwasm_expr_term_t *operterm;

	// return if we are at the end of the expression or a subexpression
	if (lwasm_expr_is_end(**p))
		return 0;

	if (lwasm_expr_parse_term(s, p) < 0)
		return -1;

	while (!lwasm_expr_is_end(**p))
	{
		// expecting an operator here: take the first table entry whose
		// text matches the input at *p (index the string, not the pointer
		// array: (*p)[i], not *p[i])
		for (opern = 0; operators[opern].opernum != LWASM_OPER_NONE; opern++)
		{
			const char *op = operators[opern].operstr;

			for (i = 0; (*p)[i] && op[i] && (*p)[i] == op[i]; i++)
				/* do nothing */ ;
			if (op[i] == '\0')
				break;
		}
		if (operators[opern].opernum == LWASM_OPER_NONE)
		{
			// unrecognized operator
			return -1;
		}

		// the operator in question is operators[opern]; i is the length of
		// its text

		// an operator that binds no tighter than the one to our left belongs
		// to an enclosing call: return without advancing the input pointer
		// (this is also what makes equal-precedence operators left
		// associative)
		if (operators[opern].operprec <= prec)
			return 0;

		// otherwise consume the operator, parse its right operand (which
		// absorbs any run of tighter-binding operators) and then push the
		// operator itself after its operands
		(*p) += i;
		if (lwasm_expr_parse_expr(s, p, operators[opern].operprec) < 0)
			return -1;

		operterm = lwasm_expr_term_create_oper(operators[opern].opernum);
		lwasm_expr_stack_push(s, operterm);
		lwasm_expr_term_free(operterm);
	}

	return 0;
}

/*
actually evaluate an expression

This happens in two stages. First the text at inp is parsed into a postfix
lwasm_expr_stack_t *; then lwasm_expr_reval() resolves symbols through
sfunc/state and folds as much of the stack as possible.

Returns the (possibly only partially reduced) stack, or NULL on a parse
error or an evaluation error. The caller owns a non-NULL result and releases
it with lwasm_expr_stack_free().

If outp is not NULL, *outp receives a pointer to the first character after
the expression once parsing succeeds. It is written before the evaluation
stage runs, so it may be set even when NULL is ultimately returned; treat it
as undefined whenever NULL is returned.
*/
lwasm_expr_stack_t *lwasm_expr_eval(const char *inp, const char **outp, int (*sfunc)(char *sym, void *state, int *val), void *state)
{
	const char *cursor = inp;
	lwasm_expr_stack_t *stack = lwasm_expr_stack_create();

	// stage 1: turn the infix text into a postfix term stack
	if (lwasm_expr_parse_expr(stack, &cursor, 0) < 0)
	{
		lwasm_expr_stack_free(stack);
		return NULL;
	}

	// record where the expression ended
	if (outp)
		*outp = cursor;

	// stage 2: resolve what we can; the result may still be partial
	if (lwasm_expr_reval(stack, sfunc, state) < 0)
	{
		lwasm_expr_stack_free(stack);
		return NULL;
	}

	if (lwasm_expr_is_constant(stack))
		debug_message(3, "Constant expression evaluates to: %d", lwasm_expr_get_value(stack));

	return stack;
}

/*
Perform one simplification step on the expression stack s.

Scans from head to tail for the first operator whose operands are integer
constants. For a unary operator (NEG, COM) that is the immediately preceding
term; for any other operator it is the two preceding terms. The result
replaces the left-most operand and the consumed nodes are unlinked and freed.

Returns 1 if a simplification was made, 0 if none was possible, or -1 on
division by zero, division overflow (INT_MIN / -1) or an unknown binary
operator.
*/
static int lwasm_expr_reval_step(lwasm_expr_stack_t *s)
{
	lwasm_expr_stack_node_t *n;

	for (n = s -> head; n; n = n -> next)
	{
		if (n -> term -> term_type != LWASM_TERM_OPER)
			continue;

		if (n -> term -> value == LWASM_OPER_NEG
			|| n -> term -> value == LWASM_OPER_COM
			)
		{
			// unary operator: resolvable only if its operand is a constant
			lwasm_expr_term_t *arg;

			if (!n -> prev || n -> prev -> term -> term_type != LWASM_TERM_INT)
				continue;

			arg = n -> prev -> term;
			if (n -> term -> value == LWASM_OPER_NEG)
				arg -> value = -(arg -> value);
			else
				arg -> value = ~(arg -> value);

			// unlink and free the operator node
			n -> prev -> next = n -> next;
			if (n -> next)
				n -> next -> prev = n -> prev;
			else
				s -> tail = n -> prev;

			lwasm_expr_term_free(n -> term);
			lwasm_free(n);
			return 1;
		}
		else
		{
			// binary operator: resolvable only if both operands are constants
			lwasm_expr_stack_node_t *rnode = n -> prev;
			lwasm_expr_stack_node_t *lnode = rnode ? rnode -> prev : NULL;
			int lhs, rhs;

			if (!lnode
				|| rnode -> term -> term_type != LWASM_TERM_INT
				|| lnode -> term -> term_type != LWASM_TERM_INT)
				continue;

			lhs = lnode -> term -> value;
			rhs = rnode -> term -> value;

			switch (n -> term -> value)
			{
			case LWASM_OPER_PLUS:
				lhs += rhs;
				break;

			case LWASM_OPER_MINUS:
				lhs -= rhs;
				break;

			case LWASM_OPER_TIMES:
				lhs *= rhs;
				break;

			case LWASM_OPER_DIVIDE:
			case LWASM_OPER_INTDIV:
				if (rhs == 0)
					return -1;
				// INT_MIN / -1 is not representable (UB; traps on many CPUs)
				if (rhs == -1 && lhs == INT_MIN)
					return -1;
				lhs /= rhs;
				break;

			case LWASM_OPER_MOD:
				if (rhs == 0)
					return -1;
				// x % -1 is always 0, and INT_MIN % -1 is UB in C
				lhs = (rhs == -1) ? 0 : lhs % rhs;
				break;

			case LWASM_OPER_BWAND:
				lhs &= rhs;
				break;

			case LWASM_OPER_BWOR:
				lhs |= rhs;
				break;

			case LWASM_OPER_BWXOR:
				lhs ^= rhs;
				break;

			case LWASM_OPER_AND:
				lhs = (rhs && lhs) ? 1 : 0;
				break;

			case LWASM_OPER_OR:
				lhs = (rhs || lhs) ? 1 : 0;
				break;

			default:
				// return error if unknown operator!
				return -1;
			}
			lnode -> term -> value = lhs;

			// now remove the right operand and the operator from the stack
			lnode -> next = n -> next;
			if (n -> next)
				n -> next -> prev = lnode;
			else
				s -> tail = lnode;

			lwasm_expr_term_free(n -> term);
			lwasm_expr_term_free(rnode -> term);
			lwasm_free(rnode);
			lwasm_free(n);
			return 1;
		}
	}

	return 0;
}

/*
take an expression stack s and scan for operations that can be completed

return -1 on error, 0 on no error

possible errors are: division by zero, division overflow (INT_MIN / -1) or
an unknown operator; on error the stack may already be partially simplified

theory of operation:

first, every symbol term is passed to sfunc (if sfunc is not NULL); a
symbol for which sfunc returns 0 is replaced by the integer it stored in its
val argument, and any other symbol is left alone

then scan the stack for an operator which has two constants preceding it
(binary) or 1 constant preceding it (unary) and if found, perform the
calculation and replace the operator and its operands with the result

repeat the scan until no further simplifications are found or only a single
term remains
*/
int lwasm_expr_reval(lwasm_expr_stack_t *s, int (*sfunc)(char *sym, void *state, int *val), void *state)
{
	lwasm_expr_stack_node_t *n;
	int sval;
	int rv;

	// resolve symbols
	// symbols that do not resolve to a constant are left alone
	if (sfunc)
	{
		for (n = s -> head; n; n = n -> next)
		{
			if (n -> term -> term_type != LWASM_TERM_SYM)
				continue;
			if (sfunc(n -> term -> symbol, state, &sval) == 0)
			{
				n -> term -> term_type = LWASM_TERM_INT;
				n -> term -> value = sval;
				lwasm_free(n -> term -> symbol);
				n -> term -> symbol = NULL;
			}
		}
	}

	// fold one operation at a time; the step helper reports progress
	// explicitly, so no pointer to a freed node is ever examined here
	while (s -> head != s -> tail)
	{
		rv = lwasm_expr_reval_step(s);
		if (rv < 0)
			return -1;
		if (rv == 0)
			break;
	}

	return 0;
}