Skip to content

Commit

Permalink
Implemented multi-predicate support #103
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-ulrich1 committed May 28, 2023
1 parent a7d9277 commit fba87dc
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 91 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions .idea/go-tree-sitter.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

201 changes: 110 additions & 91 deletions bindings.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sitter

//#include "bindings.h"
// #include "bindings.h"
import "C"

import (
Expand Down Expand Up @@ -801,44 +801,47 @@ func NewQuery(pattern []byte, lang *Language) (*Query, error) {

q := &Query{c: c}

// this is just used for syntax validation - it does not actually filter anything
for i := uint32(0); i < q.PatternCount(); i++ {
steps := q.PredicatesForPattern(i)
if len(steps) == 0 {
continue
}

if steps[0].Type != QueryPredicateStepTypeString {
return nil, errors.New("predicate must begin with a literal value")
}

operator := q.StringValueForId(steps[0].ValueId)
switch operator {
case "eq?", "not-eq?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
case "match?", "not-match?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
if steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
}
case "set!", "is?", "is-not?":
if len(steps) < 3 || len(steps) > 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
predicates := q.PredicatesForPattern(i)
for _, steps := range predicates {
if len(steps) == 0 {
continue
}
if steps[1].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))

if steps[0].Type != QueryPredicateStepTypeString {
return nil, errors.New("predicate must begin with a literal value")
}
if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))

operator := q.StringValueForId(steps[0].ValueId)
switch operator {
case "eq?", "not-eq?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
case "match?", "not-match?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
if steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
}
case "set!", "is?", "is-not?":
if len(steps) < 3 || len(steps) > 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
}
}
}
}
Expand Down Expand Up @@ -885,7 +888,7 @@ type QueryPredicateStep struct {
ValueId uint32
}

func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep {
func (q *Query) PredicatesForPattern(patternIndex uint32) [][]QueryPredicateStep {
var (
length C.uint32_t
cPredicateSteps []C.TSQueryPredicateStep
Expand All @@ -905,7 +908,7 @@ func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep {
predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId})
}

return predicateSteps
return splitPredicates(predicateSteps)
}

func (q *Query) CaptureNameForId(id uint32) string {
Expand Down Expand Up @@ -1059,6 +1062,19 @@ func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) {
return qm, uint32(captureIndex), true
}

func splitPredicates(steps []QueryPredicateStep) [][]QueryPredicateStep {
var predicateSteps [][]QueryPredicateStep
var currentSteps []QueryPredicateStep
for _, step := range steps {
currentSteps = append(currentSteps, step)
if step.Type == QueryPredicateStepTypeDone {
predicateSteps = append(predicateSteps, currentSteps)
currentSteps = []QueryPredicateStep{}
}
}
return predicateSteps
}

func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch {
qm := &QueryMatch{
ID: m.ID,
Expand All @@ -1067,87 +1083,90 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch

q := qc.q

steps := q.PredicatesForPattern(uint32(qm.PatternIndex))
if len(steps) == 0 {
predicates := q.PredicatesForPattern(uint32(qm.PatternIndex))
if len(predicates) == 0 {
qm.Captures = m.Captures
return qm
}

operator := q.StringValueForId(steps[0].ValueId)
// track if we matched all predicates globally
matchedAll := true

switch operator {
case "eq?", "not-eq?":
isPositive := operator == "eq?"
// check each predicate against the match
for _, steps := range predicates {
operator := q.StringValueForId(steps[0].ValueId)

expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)
switch operator {
case "eq?", "not-eq?":
isPositive := operator == "eq?"

if steps[2].Type == QueryPredicateStepTypeCapture {
expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)
expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)

var nodeLeft, nodeRight *Node
if steps[2].Type == QueryPredicateStepTypeCapture {
expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)

found := false
var nodeLeft, nodeRight *Node

for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)
qm.Captures = append(qm.Captures, c)
for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)

if captureName == expectedCaptureNameLeft {
nodeLeft = c.Node
}
if captureName == expectedCaptureNameRight {
nodeRight = c.Node
if captureName == expectedCaptureNameLeft {
nodeLeft = c.Node
}
if captureName == expectedCaptureNameRight {
nodeRight = c.Node
}

if nodeLeft != nil && nodeRight != nil {
if (nodeLeft.Content(input) == nodeRight.Content(input)) != isPositive {
matchedAll = false
}
break
}
}
} else {
expectedValueRight := q.StringValueForId(steps[2].ValueId)

if nodeLeft != nil && nodeRight != nil {
if (nodeLeft.Content(input) == nodeRight.Content(input)) == isPositive {
found = true
for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)

if expectedCaptureNameLeft != captureName {
continue
}

if (c.Node.Content(input) == expectedValueRight) != isPositive {
matchedAll = false
break
}
break
}
}

if !found {
qm.Captures = nil
if matchedAll == false {
break
}
} else {
expectedValueRight := q.StringValueForId(steps[2].ValueId)

found := false
case "match?", "not-match?":
isPositive := operator == "match?"

expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))

for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)

qm.Captures = append(qm.Captures, c)
if expectedCaptureNameLeft != captureName {
if expectedCaptureName != captureName {
continue
}

if (c.Node.Content(input) == expectedValueRight) == isPositive {
found = true
if regex.Match([]byte(c.Node.Content(input))) != isPositive {
matchedAll = false
break
}
}

if !found {
qm.Captures = nil
}
}
}

case "match?", "not-match?":
isPositive := operator == "match?"

expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))

for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)
if expectedCaptureName != captureName {
continue
}

if regex.Match([]byte(c.Node.Content(input))) == isPositive {
qm.Captures = append(qm.Captures, c)
}
}
if matchedAll {
qm.Captures = append(qm.Captures, m.Captures...)
}

return qm
Expand Down
31 changes: 31 additions & 0 deletions predicates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func TestQueryWithPredicates(t *testing.T) {
msg: "#eq?: success test",
pattern: `((expression) @capture
(#eq? @capture "this"))`,
},
{
success: true,
msg: "#eq?: success double predicate test",
pattern: `((expression) @capture
(#eq? @capture @capture) (#eq? @capture "this"))`,
},
{
success: true,
Expand Down Expand Up @@ -287,6 +293,13 @@ func TestFilterPredicates(t *testing.T) {
expectedBefore: 1,
expectedAfter: 0,
},
{
input: `// foo`,
query: `((comment) @capture
(#eq? @capture "// foo") (#eq? @capture "// bar"))`,
expectedBefore: 1,
expectedAfter: 0,
},
{
input: `1234 + 1234`,
query: `((sum
Expand Down Expand Up @@ -346,6 +359,24 @@ func TestFilterPredicates(t *testing.T) {
expectedBefore: 2,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
query: `((sum
left: (expression (number) @left)
right: (expression (number) @right))
(#eq? @left 1234) (#not-eq? @left @right))`,
expectedBefore: 2,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
query: `((sum
left: (expression (number) @left)
right: (expression (number) @right))
(#eq? @left 1234) (#eq? @left 4321))`,
expectedBefore: 2,
expectedAfter: 0,
},
}

parser := NewParser()
Expand Down

0 comments on commit fba87dc

Please sign in to comment.