Skip to content

Commit

Permalink
Merge pull request #102 from wesen/misc/improve-query-handling
Browse files Browse the repository at this point in the history
Improve query handling and fix predicates filtering
  • Loading branch information
smacker authored May 1, 2023
2 parents af16b10 + a896a22 commit a7d9277
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 37 deletions.
162 changes: 128 additions & 34 deletions bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"unsafe"
Expand Down Expand Up @@ -678,36 +679,39 @@ const (
QueryErrorNodeType
QueryErrorField
QueryErrorCapture
QueryErrorStructure
QueryErrorLanguage
)

// QueryErrorTypeToString returns a short lower-case label for errorType,
// intended for embedding in human-readable error messages.
//
// Any value without a dedicated label yields "unknown".
// NOTE(review): QueryErrorStructure and QueryErrorLanguage are declared next
// to the other error types but have no label here, so they currently report
// "unknown" — confirm whether that is intended.
func QueryErrorTypeToString(errorType QueryErrorType) string {
	label := "unknown"

	switch errorType {
	case QueryErrorSyntax:
		label = "syntax"
	case QueryErrorCapture:
		label = "capture"
	case QueryErrorField:
		label = "field"
	case QueryErrorNodeType:
		label = "node type"
	case QueryErrorNone:
		label = "none"
	}

	return label
}

// QueryError describes a failure to compile a query.
// Offset is the byte offset of the error within the query source,
// Type indicates what kind of error occurred,
// and Message is the human-readable description of the error.
type QueryError struct {
Offset uint32
Type QueryErrorType
Offset uint32
Type QueryErrorType
Message string
}

// Error implements the error interface for QueryError.
func (qe *QueryError) Error() string {
switch qe.Type {
case QueryErrorNone:
return ""

case QueryErrorSyntax:
return fmt.Sprintf("syntax error (offset: %d)", qe.Offset)

case QueryErrorNodeType:
return fmt.Sprintf("node type error (offset: %d)", qe.Offset)

case QueryErrorField:
return fmt.Sprintf("field error (offset: %d)", qe.Offset)

case QueryErrorCapture:
return fmt.Sprintf("capture error (offset: %d)", qe.Offset)

default:
return fmt.Sprintf("unknown error (offset: %d)", qe.Offset)
}
// NOTE(review): every arm of the switch above returns, so as listed this
// statement is unreachable; the listing appears to interleave removed and
// added diff lines — confirm against the committed file.
return qe.Message
}

// Query API
Expand All @@ -734,7 +738,65 @@ func NewQuery(pattern []byte, lang *Language) (*Query, error) {
)
C.free(input)
if errtype != C.TSQueryError(QueryErrorNone) {
return nil, &QueryError{Offset: uint32(erroff), Type: QueryErrorType(errtype)}
errorOffset := uint32(erroff)
// search for the line containing the offset
line := 1
line_start := 0
for i, c := range pattern {
line_start = i
if uint32(i) >= errorOffset {
break
}
if c == '\n' {
line++
}
}
column := int(errorOffset) - line_start
errorType := QueryErrorType(errtype)
errorTypeToString := QueryErrorTypeToString(errorType)

var message string
switch errorType {
// errors that apply to a single identifier
case QueryErrorNodeType:
fallthrough
case QueryErrorField:
fallthrough
case QueryErrorCapture:
// find identifier at input[errorOffset]
// and report it in the error message
s := string(pattern[errorOffset:])
identifierRegexp := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_-]*`)
m := identifierRegexp.FindStringSubmatch(s)
if len(m) > 0 {
message = fmt.Sprintf("invalid %s '%s' at line %d column %d",
errorTypeToString, m[0], line, column)
} else {
message = fmt.Sprintf("invalid %s at line %d column %d",
errorTypeToString, line, column)
}

// errors that report a position
case QueryErrorSyntax:
fallthrough
case QueryErrorStructure:
fallthrough
case QueryErrorLanguage:
fallthrough
default:
s := string(pattern[errorOffset:])
lines := strings.Split(s, "\n")
whitespace := strings.Repeat(" ", column)
message = fmt.Sprintf("invalid %s at line %d column %d\n%s\n%s^",
errorTypeToString, line, column,
lines[0], whitespace)
}

return nil, &QueryError{
Offset: errorOffset,
Type: errorType,
Message: message,
}
}

q := &Query{c: c}
Expand Down Expand Up @@ -858,6 +920,20 @@ func (q *Query) StringValueForId(id uint32) string {
return C.GoStringN(value, C.int(length))
}

// Quantifier is the result type of Query.CaptureQuantifierForId.
// NOTE(review): presumably it mirrors the TSQuantifier enum of the
// tree-sitter C API (how many times a capture may occur in a pattern) —
// confirm against api.h.
type Quantifier int

// Quantifier values. They are written out explicitly because they are
// produced by converting a value returned from C, so their numeric values
// must not drift. The constants are deliberately left untyped.
const (
	QuantifierZero       = 0
	QuantifierZeroOrOne  = 1
	QuantifierZeroOrMore = 2
	QuantifierOne        = 3
	QuantifierOneOrMore  = 4
)

// CaptureQuantifierForId returns the Quantifier that the underlying
// ts_query_capture_quantifier_for_id call reports for the given pair of ids.
// NOTE(review): in the tree-sitter C API the two arguments are a pattern
// index and a capture index — confirm that id and captureId map to those.
func (q *Query) CaptureQuantifierForId(id uint32, captureId uint32) Quantifier {
	firstArg := C.uint32_t(id)
	secondArg := C.uint32_t(captureId)

	raw := C.ts_query_capture_quantifier_for_id(q.c, firstArg, secondArg)

	return Quantifier(raw)
}

// QueryCursor carries the state needed for processing the queries.
type QueryCursor struct {
c *C.TSQueryCursor
Expand Down Expand Up @@ -989,27 +1065,32 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch
PatternIndex: m.PatternIndex,
}

steps := qc.q.PredicatesForPattern(uint32(qm.PatternIndex))
q := qc.q

steps := q.PredicatesForPattern(uint32(qm.PatternIndex))
if len(steps) == 0 {
qm.Captures = m.Captures
return qm
}

operator := qc.q.StringValueForId(steps[0].ValueId)
operator := q.StringValueForId(steps[0].ValueId)

switch operator {
case "eq?", "not-eq?":
isPositive := operator == "eq?"

expectedCaptureNameLeft := qc.q.CaptureNameForId(steps[1].ValueId)
expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)

if steps[2].Type == QueryPredicateStepTypeCapture {
expectedCaptureNameRight := qc.q.CaptureNameForId(steps[2].ValueId)
expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)

var nodeLeft, nodeRight *Node

found := false

for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)
qm.Captures = append(qm.Captures, c)

if captureName == expectedCaptureNameLeft {
nodeLeft = c.Node
Expand All @@ -1020,33 +1101,45 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch

if nodeLeft != nil && nodeRight != nil {
if (nodeLeft.Content(input) == nodeRight.Content(input)) == isPositive {
qm.Captures = append(qm.Captures, c)
found = true
}
break
}
}

if !found {
qm.Captures = nil
}
} else {
expectedValueRight := qc.q.StringValueForId(steps[2].ValueId)
expectedValueRight := q.StringValueForId(steps[2].ValueId)

found := false
for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)

qm.Captures = append(qm.Captures, c)
if expectedCaptureNameLeft != captureName {
continue
}

if (c.Node.Content(input) == expectedValueRight) == isPositive {
qm.Captures = append(qm.Captures, c)
found = true
}
}

if !found {
qm.Captures = nil
}
}

case "match?", "not-match?":
isPositive := operator == "match?"

expectedCaptureName := qc.q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(qc.q.StringValueForId(steps[2].ValueId))
expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))

This comment has been minimized.

Copy link
@sgtroy88

sgtroy88 Jun 4, 2023

Test


for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)
if expectedCaptureName != captureName {
continue
}
Expand All @@ -1058,6 +1151,7 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch
}

return qm

}

// keeps callbacks for parser.parse method
Expand Down
3 changes: 2 additions & 1 deletion bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ func TestQueryError(t *testing.T) {

assert.Nil(q)
assert.NotNil(err)
assert.EqualValues(&QueryError{Offset: 0x02, Type: QueryErrorNodeType}, err)
assert.EqualValues(&QueryError{Offset: 0x02, Type: QueryErrorNodeType,
Message: "invalid node type 'unknown' at line 1 column 0"}, err)
}

func doWorkLifetime(t testing.TB, n *Node) {
Expand Down
13 changes: 11 additions & 2 deletions predicates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func TestFilterPredicates(t *testing.T) {
right: (expression (number) @right))
(#eq? @left @right))`,
expectedBefore: 2,
expectedAfter: 1,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
Expand Down Expand Up @@ -335,7 +335,16 @@ func TestFilterPredicates(t *testing.T) {
right: (expression (number) @right))
(#not-eq? @left @right))`,
expectedBefore: 2,
expectedAfter: 1,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
query: `((sum
left: (expression (number) @left)
right: (expression (number) @right))
(#eq? @left 1234))`,
expectedBefore: 2,
expectedAfter: 2,
},
}

Expand Down

0 comments on commit a7d9277

Please sign in to comment.