diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 00000000..13566b81
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/go-tree-sitter.iml b/.idea/go-tree-sitter.iml
new file mode 100644
index 00000000..5e764c4f
--- /dev/null
+++ b/.idea/go-tree-sitter.iml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 00000000..12e6436c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000..35eb1ddf
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/bindings.go b/bindings.go
index 1c241895..fa4985f5 100644
--- a/bindings.go
+++ b/bindings.go
@@ -1,6 +1,6 @@
package sitter
-//#include "bindings.h"
+// #include "bindings.h"
import "C"
import (
@@ -801,44 +801,47 @@ func NewQuery(pattern []byte, lang *Language) (*Query, error) {
q := &Query{c: c}
+ // this is just used for syntax validation - it does not actually filter anything
for i := uint32(0); i < q.PatternCount(); i++ {
- steps := q.PredicatesForPattern(i)
- if len(steps) == 0 {
- continue
- }
-
- if steps[0].Type != QueryPredicateStepTypeString {
- return nil, errors.New("predicate must begin with a literal value")
- }
-
- operator := q.StringValueForId(steps[0].ValueId)
- switch operator {
- case "eq?", "not-eq?":
- if len(steps) != 4 {
- return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
- }
- if steps[1].Type != QueryPredicateStepTypeCapture {
- return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
- }
- case "match?", "not-match?":
- if len(steps) != 4 {
- return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
- }
- if steps[1].Type != QueryPredicateStepTypeCapture {
- return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
- }
- if steps[2].Type != QueryPredicateStepTypeString {
- return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
- }
- case "set!", "is?", "is-not?":
- if len(steps) < 3 || len(steps) > 4 {
- return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
+ predicates := q.PredicatesForPattern(i)
+ for _, steps := range predicates {
+ if len(steps) == 0 {
+ continue
}
- if steps[1].Type != QueryPredicateStepTypeString {
- return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))
+
+ if steps[0].Type != QueryPredicateStepTypeString {
+ return nil, errors.New("predicate must begin with a literal value")
}
- if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
- return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
+
+ operator := q.StringValueForId(steps[0].ValueId)
+ switch operator {
+ case "eq?", "not-eq?":
+ if len(steps) != 4 {
+ return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
+ }
+ if steps[1].Type != QueryPredicateStepTypeCapture {
+ return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
+ }
+ case "match?", "not-match?":
+ if len(steps) != 4 {
+ return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
+ }
+ if steps[1].Type != QueryPredicateStepTypeCapture {
+ return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
+ }
+ if steps[2].Type != QueryPredicateStepTypeString {
+ return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
+ }
+ case "set!", "is?", "is-not?":
+ if len(steps) < 3 || len(steps) > 4 {
+ return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
+ }
+ if steps[1].Type != QueryPredicateStepTypeString {
+ return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))
+ }
+ if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
+ return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
+ }
}
}
}
@@ -885,7 +888,7 @@ type QueryPredicateStep struct {
ValueId uint32
}
-func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep {
+func (q *Query) PredicatesForPattern(patternIndex uint32) [][]QueryPredicateStep {
var (
length C.uint32_t
cPredicateSteps []C.TSQueryPredicateStep
@@ -905,7 +908,7 @@ func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep {
predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId})
}
- return predicateSteps
+ return splitPredicates(predicateSteps)
}
func (q *Query) CaptureNameForId(id uint32) string {
@@ -1059,6 +1062,19 @@ func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) {
return qm, uint32(captureIndex), true
}
+func splitPredicates(steps []QueryPredicateStep) [][]QueryPredicateStep {
+ var predicateSteps [][]QueryPredicateStep
+ var currentSteps []QueryPredicateStep
+ for _, step := range steps {
+ currentSteps = append(currentSteps, step)
+ if step.Type == QueryPredicateStepTypeDone {
+ predicateSteps = append(predicateSteps, currentSteps)
+ currentSteps = []QueryPredicateStep{}
+ }
+ }
+ return predicateSteps
+}
+
func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch {
qm := &QueryMatch{
ID: m.ID,
@@ -1067,87 +1083,90 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch
q := qc.q
- steps := q.PredicatesForPattern(uint32(qm.PatternIndex))
- if len(steps) == 0 {
+ predicates := q.PredicatesForPattern(uint32(qm.PatternIndex))
+ if len(predicates) == 0 {
qm.Captures = m.Captures
return qm
}
- operator := q.StringValueForId(steps[0].ValueId)
+ // track if we matched all predicates globally
+ matchedAll := true
- switch operator {
- case "eq?", "not-eq?":
- isPositive := operator == "eq?"
+ // check each predicate against the match
+ for _, steps := range predicates {
+ operator := q.StringValueForId(steps[0].ValueId)
- expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)
+ switch operator {
+ case "eq?", "not-eq?":
+ isPositive := operator == "eq?"
- if steps[2].Type == QueryPredicateStepTypeCapture {
- expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)
+ expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)
- var nodeLeft, nodeRight *Node
+ if steps[2].Type == QueryPredicateStepTypeCapture {
+ expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)
- found := false
+ var nodeLeft, nodeRight *Node
- for _, c := range m.Captures {
- captureName := q.CaptureNameForId(c.Index)
- qm.Captures = append(qm.Captures, c)
+ for _, c := range m.Captures {
+ captureName := q.CaptureNameForId(c.Index)
- if captureName == expectedCaptureNameLeft {
- nodeLeft = c.Node
- }
- if captureName == expectedCaptureNameRight {
- nodeRight = c.Node
+ if captureName == expectedCaptureNameLeft {
+ nodeLeft = c.Node
+ }
+ if captureName == expectedCaptureNameRight {
+ nodeRight = c.Node
+ }
+
+ if nodeLeft != nil && nodeRight != nil {
+ if (nodeLeft.Content(input) == nodeRight.Content(input)) != isPositive {
+ matchedAll = false
+ }
+ break
+ }
}
+ } else {
+ expectedValueRight := q.StringValueForId(steps[2].ValueId)
- if nodeLeft != nil && nodeRight != nil {
- if (nodeLeft.Content(input) == nodeRight.Content(input)) == isPositive {
- found = true
+ for _, c := range m.Captures {
+ captureName := q.CaptureNameForId(c.Index)
+
+ if expectedCaptureNameLeft != captureName {
+ continue
+ }
+
+ if (c.Node.Content(input) == expectedValueRight) != isPositive {
+ matchedAll = false
+ break
}
- break
}
}
- if !found {
- qm.Captures = nil
+ if matchedAll == false {
+ break
}
- } else {
- expectedValueRight := q.StringValueForId(steps[2].ValueId)
- found := false
+ case "match?", "not-match?":
+ isPositive := operator == "match?"
+
+ expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
+ regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))
+
for _, c := range m.Captures {
captureName := q.CaptureNameForId(c.Index)
-
- qm.Captures = append(qm.Captures, c)
- if expectedCaptureNameLeft != captureName {
+ if expectedCaptureName != captureName {
continue
}
- if (c.Node.Content(input) == expectedValueRight) == isPositive {
- found = true
+ if regex.Match([]byte(c.Node.Content(input))) != isPositive {
+ matchedAll = false
+ break
}
}
-
- if !found {
- qm.Captures = nil
- }
}
+ }
- case "match?", "not-match?":
- isPositive := operator == "match?"
-
- expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
- regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))
-
- for _, c := range m.Captures {
- captureName := q.CaptureNameForId(c.Index)
- if expectedCaptureName != captureName {
- continue
- }
-
- if regex.Match([]byte(c.Node.Content(input))) == isPositive {
- qm.Captures = append(qm.Captures, c)
- }
- }
+ if matchedAll {
+ qm.Captures = append(qm.Captures, m.Captures...)
}
return qm
diff --git a/predicates_test.go b/predicates_test.go
index cabe7e84..c985b677 100644
--- a/predicates_test.go
+++ b/predicates_test.go
@@ -96,6 +96,12 @@ func TestQueryWithPredicates(t *testing.T) {
msg: "#eq?: success test",
pattern: `((expression) @capture
(#eq? @capture "this"))`,
+ },
+ {
+ success: true,
+ msg: "#eq?: success double predicate test",
+ pattern: `((expression) @capture
+ (#eq? @capture @capture) (#eq? @capture "this"))`,
},
{
success: true,
@@ -287,6 +293,13 @@ func TestFilterPredicates(t *testing.T) {
expectedBefore: 1,
expectedAfter: 0,
},
+ {
+ input: `// foo`,
+ query: `((comment) @capture
+ (#eq? @capture "// foo") (#eq? @capture "// bar"))`,
+ expectedBefore: 1,
+ expectedAfter: 0,
+ },
{
input: `1234 + 1234`,
query: `((sum
@@ -346,6 +359,24 @@ func TestFilterPredicates(t *testing.T) {
expectedBefore: 2,
expectedAfter: 2,
},
+ {
+ input: `1234 + 4321`,
+ query: `((sum
+ left: (expression (number) @left)
+ right: (expression (number) @right))
+ (#eq? @left 1234) (#not-eq? @left @right))`,
+ expectedBefore: 2,
+ expectedAfter: 2,
+ },
+ {
+ input: `1234 + 4321`,
+ query: `((sum
+ left: (expression (number) @left)
+ right: (expression (number) @right))
+ (#eq? @left 1234) (#eq? @left 4321))`,
+ expectedBefore: 2,
+ expectedAfter: 0,
+ },
}
parser := NewParser()