290 lines
6.4 KiB
Go
290 lines
6.4 KiB
Go
package expr
|
|
|
|
import (
	"bytes"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"text/scanner"
	"unicode"
)
|
|
|
|
// Token kinds produced by the lexer. They are negative so they never
// collide with plain single-character tokens such as '(' or ',', which the
// lexer stores as their (non-negative) rune value.
const (
	TokenEOF      = -1 // end of input
	TokenIdent    = -2 // identifier
	TokenInt      = -3 // integer literal
	TokenFloat    = -4 // floating-point literal
	TokenOperator = -5 // operator such as "+", "&&" or ">="
)
|
|
|
|
// lexer wraps a text/scanner Scanner and keeps one token of lookahead:
// the kind of the current token and its source text.
type lexer struct {
	scan  scanner.Scanner // underlying scanner
	token rune            // current token kind: a Token* constant or a plain rune such as '('
	text  string          // source text of the current token
}

// getToken returns the kind of the current token.
func (lex *lexer) getToken() rune { return lex.token }

// getText returns the source text of the current token.
func (lex *lexer) getText() string { return lex.text }
|
|
|
|
func (lex *lexer) next() {
|
|
token := lex.scan.Scan()
|
|
text := lex.scan.TokenText()
|
|
switch token {
|
|
case scanner.EOF:
|
|
lex.token = TokenEOF
|
|
lex.text = text
|
|
case scanner.Ident:
|
|
lex.token = TokenIdent
|
|
lex.text = text
|
|
case scanner.Int:
|
|
lex.token = TokenInt
|
|
lex.text = text
|
|
case scanner.Float:
|
|
lex.token = TokenFloat
|
|
lex.text = text
|
|
case '+', '-', '*', '/', '%', '~':
|
|
lex.token = TokenOperator
|
|
lex.text = text
|
|
case '&', '|', '=':
|
|
var buffer bytes.Buffer
|
|
lex.token = TokenOperator
|
|
buffer.WriteRune(token)
|
|
next := lex.scan.Peek()
|
|
if next == token {
|
|
buffer.WriteRune(next)
|
|
lex.scan.Scan()
|
|
}
|
|
lex.text = buffer.String()
|
|
case '>', '<', '!':
|
|
var buffer bytes.Buffer
|
|
lex.token = TokenOperator
|
|
buffer.WriteRune(token)
|
|
next := lex.scan.Peek()
|
|
if next == '=' {
|
|
buffer.WriteRune(next)
|
|
lex.scan.Scan()
|
|
}
|
|
lex.text = buffer.String()
|
|
default:
|
|
if token >= 0 {
|
|
lex.token = token
|
|
lex.text = text
|
|
} else {
|
|
msg := fmt.Sprintf("got unknown token:%q, text:%s", lex.token, lex.text)
|
|
panic(lexPanic(msg))
|
|
}
|
|
}
|
|
//fmt.Printf("token:%d, text:%s\n", lex.token, lex.text)
|
|
}
|
|
|
|
// lexPanic is the panic payload the lexer and parser use to report an
// input error. Parse recovers values of this type and returns them as an
// error; any other panic value is re-raised.
type lexPanic string
|
|
|
|
// describe returns a string describing the current token, for use in errors.
|
|
func (lex *lexer) describe() string {
|
|
switch lex.token {
|
|
case TokenEOF:
|
|
return "end of file"
|
|
case TokenIdent:
|
|
return fmt.Sprintf("identifier %s", lex.getText())
|
|
case TokenInt, TokenFloat:
|
|
return fmt.Sprintf("number %s", lex.getText())
|
|
}
|
|
return fmt.Sprintf("%q", rune(lex.getToken())) // any other rune
|
|
}
|
|
|
|
func precedence(token rune, text string) int {
|
|
if token == TokenOperator {
|
|
switch text {
|
|
case "~", "!":
|
|
return 9
|
|
case "*", "/", "%":
|
|
return 8
|
|
case "+", "-":
|
|
return 7
|
|
case ">", ">=", "<", "<=":
|
|
return 6
|
|
case "!=", "==", "=":
|
|
return 5
|
|
case "&":
|
|
return 4
|
|
case "|":
|
|
return 3
|
|
case "&&":
|
|
return 2
|
|
case "||":
|
|
return 1
|
|
default:
|
|
msg := fmt.Sprintf("unknown operator:%s", text)
|
|
panic(lexPanic(msg))
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// ---- parser ----
|
|
// ExpressionParser parses expression strings (see Parse) and keeps the
// result of the most recent Parse call.
type ExpressionParser struct {
	// expression is the tree built by the last successful Parse; nil otherwise.
	expression Expr
	// variable is the set of bare identifiers seen while parsing; function
	// names are not recorded. Used as a set, hence the empty-struct values.
	variable map[string]struct{}
}
|
|
|
|
func NewExpressionParser() *ExpressionParser {
|
|
return &ExpressionParser{
|
|
expression: nil,
|
|
variable: make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
// Parse parses the input string as an arithmetic expression.
|
|
//
|
|
// expr = num a literal number, e.g., 3.14159
|
|
// | id a variable name, e.g., x
|
|
// | id '(' expr ',' ... ')' a function call
|
|
// | '-' expr a unary operator ( + - ! )
|
|
// | expr '+' expr a binary operator ( + - * / && & || | == )
|
|
//
|
|
func (parser *ExpressionParser) Parse(input string) (err error) {
|
|
defer func() {
|
|
switch x := recover().(type) {
|
|
case nil:
|
|
// no panic
|
|
case lexPanic:
|
|
err = fmt.Errorf("%s", x)
|
|
default:
|
|
// unexpected panic: resume state of panic.
|
|
panic(x)
|
|
}
|
|
}()
|
|
lex := new(lexer)
|
|
lex.scan.Init(strings.NewReader(input))
|
|
lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats
|
|
lex.scan.IsIdentRune = parser.isIdentRune
|
|
lex.next() // initial lookahead
|
|
parser.expression = nil
|
|
parser.variable = make(map[string]struct{})
|
|
e := parser.parseExpr(lex)
|
|
if lex.token != scanner.EOF {
|
|
return fmt.Errorf("unexpected %s", lex.describe())
|
|
}
|
|
parser.expression = e
|
|
return nil
|
|
}
|
|
|
|
func (parser *ExpressionParser) GetExpr() Expr {
|
|
return parser.expression
|
|
}
|
|
|
|
func (parser *ExpressionParser) GetVariable() []string {
|
|
variable := make([]string, 0, len(parser.variable))
|
|
for v := range parser.variable {
|
|
if v != "true" && v != "false" {
|
|
variable = append(variable, v)
|
|
}
|
|
}
|
|
return variable
|
|
}
|
|
|
|
func (parser *ExpressionParser) isIdentRune(ch rune, i int) bool {
|
|
return ch == '$' || ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && i > 0
|
|
}
|
|
|
|
func (parser *ExpressionParser) parseExpr(lex *lexer) Expr {
|
|
return parser.parseBinary(lex, 1)
|
|
}
|
|
|
|
// binary = unary ('+' binary)*
|
|
// parseBinary stops when it encounters an
|
|
// operator of lower precedence than prec1.
|
|
func (parser *ExpressionParser) parseBinary(lex *lexer, prec1 int) Expr {
|
|
lhs := parser.parseUnary(lex)
|
|
for prec := precedence(lex.getToken(), lex.getText()); prec >= prec1; prec-- {
|
|
for precedence(lex.getToken(), lex.getText()) == prec {
|
|
op := lex.getText()
|
|
lex.next() // consume operator
|
|
rhs := parser.parseBinary(lex, prec+1)
|
|
lhs = binary{op, lhs, rhs}
|
|
}
|
|
}
|
|
return lhs
|
|
}
|
|
|
|
// unary = '+' expr | primary
|
|
func (parser *ExpressionParser) parseUnary(lex *lexer) Expr {
|
|
if lex.getToken() == TokenOperator {
|
|
op := lex.getText()
|
|
if op == "+" || op == "-" || op == "~" || op == "!" {
|
|
lex.next()
|
|
return unary{op, parser.parseUnary(lex)}
|
|
} else {
|
|
msg := fmt.Sprintf("unary got unknown operator:%s", lex.getText())
|
|
panic(lexPanic(msg))
|
|
}
|
|
}
|
|
return parser.parsePrimary(lex)
|
|
}
|
|
|
|
// primary = id
|
|
// | id '(' expr ',' ... ',' expr ')'
|
|
// | num
|
|
// | '(' expr ')'
|
|
func (parser *ExpressionParser) parsePrimary(lex *lexer) Expr {
|
|
switch lex.token {
|
|
case TokenIdent:
|
|
id := lex.getText()
|
|
lex.next()
|
|
if lex.token != '(' {
|
|
parser.variable[id] = struct{}{}
|
|
return Var(id)
|
|
}
|
|
lex.next() // consume '('
|
|
var args []Expr
|
|
if lex.token != ')' {
|
|
for {
|
|
args = append(args, parser.parseExpr(lex))
|
|
if lex.token != ',' {
|
|
break
|
|
}
|
|
lex.next() // consume ','
|
|
}
|
|
if lex.token != ')' {
|
|
msg := fmt.Sprintf("got %q, want ')'", lex.token)
|
|
panic(lexPanic(msg))
|
|
}
|
|
}
|
|
lex.next() // consume ')'
|
|
return call{id, args}
|
|
|
|
case TokenFloat:
|
|
f, err := strconv.ParseFloat(lex.getText(), 64)
|
|
if err != nil {
|
|
panic(lexPanic(err.Error()))
|
|
}
|
|
lex.next() // consume number
|
|
return literal{value: f}
|
|
|
|
case TokenInt:
|
|
i, err := strconv.ParseInt(lex.getText(), 10, 64)
|
|
if err != nil {
|
|
panic(lexPanic(err.Error()))
|
|
}
|
|
lex.next() // consume number
|
|
return literal{value: i}
|
|
|
|
case '(':
|
|
lex.next() // consume '('
|
|
e := parser.parseExpr(lex)
|
|
if lex.token != ')' {
|
|
msg := fmt.Sprintf("got %s, want ')'", lex.describe())
|
|
panic(lexPanic(msg))
|
|
}
|
|
lex.next() // consume ')'
|
|
return e
|
|
}
|
|
msg := fmt.Sprintf("unexpected %s", lex.describe())
|
|
panic(lexPanic(msg))
|
|
}
|