diff --git a/lua/treewalker/nodes.lua b/lua/treewalker/nodes.lua index 688d831..512493c 100644 --- a/lua/treewalker/nodes.lua +++ b/lua/treewalker/nodes.lua @@ -1,32 +1,43 @@ local util = require "treewalker.util" -local lines= require "treewalker.lines" +local lines = require "treewalker.lines" -local TARGET_TYPE_BLACKLIST = { - "^.*comment.*$", +-- These are regexes but just happen to be real simple so far +local TARGET_BLACKLIST_TYPE_MATCHERS = { + "comment", } -local HIGHLIGHT_BLACKLIST_TYPES = { - "body_statement", -- lua, rb - "block", -- lua - "statement_block", -- lua - "program", -- rb +local HIGHLIGHT_BLACKLIST_TYPE_MATCHERS = { + "chunk", + "body", + "block", + "program", } + local M = {} ---@param node TSNode +---@param matchers string[] ---@return boolean -function M.is_jump_target(node) - for _, matcher in ipairs(TARGET_TYPE_BLACKLIST) do +local function is_matched_in(node, matchers) + for _, matcher in ipairs(matchers) do if node:type():match(matcher) then - return false + return true end end - return true + return false +end + +---@param node TSNode +---@return boolean +function M.is_jump_target(node) + return not is_matched_in(node, TARGET_BLACKLIST_TYPE_MATCHERS) end +---@param node TSNode +---@return boolean function M.is_highlight_target(node) - return util.contains(HIGHLIGHT_BLACKLIST_TYPES, node:type()) + return not is_matched_in(node, HIGHLIGHT_BLACKLIST_TYPE_MATCHERS) end ---Do the nodes have the same starting point @@ -79,62 +90,53 @@ end ---@param node TSNode ---@return TSNode[] function M.get_descendants(node) - local descendants = {} - - -- Helper function to recursively collect descendants - local function collect_descendants(current_node) - local child_count = current_node:child_count() - for i = 0, child_count - 1 do - local child = current_node:child(i) - table.insert(descendants, child) - -- Recursively collect descendants of the child - collect_descendants(child) - end + local descendants = {} + + -- Helper function to recursively collect descendants + local function collect_descendants(current_node) + local child_count = current_node:child_count() + for i = 0, child_count - 1 do + local child = current_node:child(i) + table.insert(descendants, child) + -- Recursively collect descendants of the child + collect_descendants(child) end + end - -- Start the recursive collection with the given node - collect_descendants(node) + -- Start the recursive collection with the given node + collect_descendants(node) - return descendants + return descendants end +-- Get farthest ancestor (or self) at the same starting coordinates ---@param node TSNode ---@return TSNode function M.get_farthest_ancestor_with_same_srow(node) - local node_row = node:range() - local farthest_ancestor = node - local iter_row = node:range() - local iter = node:parent() - - - while iter do - iter_row = iter:range() - if iter_row ~= node_row then - break - end - farthest_ancestor = iter - iter = iter:parent() + local parent = node:parent() + while parent and M.have_same_start(node, parent) do + if M.is_highlight_target(parent) then node = parent end + parent = parent:parent() end - - return farthest_ancestor + return node end --- Take a list of nodes and unique them based on line start ---@param nodes TSNode[] ---@return TSNode[] function M.unique_per_line(nodes) - local unique_nodes = {} - local seen_lines = {} - - for _, node in ipairs(nodes) do - local line = node:start() -- Assuming node:start() returns the line number of the node - if not seen_lines[line] then - table.insert(unique_nodes, node) - seen_lines[line] = true - end + local unique_nodes = {} + local seen_lines = {} + + for _, node in ipairs(nodes) do + local line = node:start() -- Assuming node:start() returns the line number of the node + if not seen_lines[line] then + table.insert(unique_nodes, node) + seen_lines[line] = true end + end - return unique_nodes + return unique_nodes end -- Easy conversion to table diff --git a/lua/treewalker/ops.lua b/lua/treewalker/ops.lua index 488e397..7d25706 100644 --- a/lua/treewalker/ops.lua +++ b/lua/treewalker/ops.lua @@ -63,12 +63,7 @@ function M.jump(row, node) vim.api.nvim_win_set_cursor(0, { row, 0 }) vim.cmd('normal! ^') if require("treewalker").opts.highlight then - -- Get farthest ancestor (or self) at the same starting coordinates - local parent = node:parent() - while parent and nodes.have_same_start(node, parent) do - if nodes.is_highlight_target(parent) then node = parent end - parent = parent:parent() - end + node = nodes.get_farthest_ancestor_with_same_srow(node) M.highlight(nodes.range(node)) end end