diff --git a/_automation/treesitter_updater/main.go b/_automation/treesitter_updater/main.go index 8f160f86..95b3c911 100644 --- a/_automation/treesitter_updater/main.go +++ b/_automation/treesitter_updater/main.go @@ -16,7 +16,7 @@ import ( ) // Constants for the Tree Sitter version and download URL -const sitterVersion = "0.20.9" +const sitterVersion = "0.22.5" const sitterURL = "https://github.com/tree-sitter/tree-sitter/archive/refs/tags/v" + sitterVersion + ".tar.gz" func main() { diff --git a/alloc.c b/alloc.c index a6adaf5e..4bd2d8c6 100644 --- a/alloc.c +++ b/alloc.c @@ -1,4 +1,5 @@ #include "alloc.h" +#include "api.h" #include static void *ts_malloc_default(size_t size) { @@ -29,10 +30,10 @@ static void *ts_realloc_default(void *buffer, size_t size) { } // Allow clients to override allocation functions dynamically -void *(*ts_current_malloc)(size_t) = ts_malloc_default; -void *(*ts_current_calloc)(size_t, size_t) = ts_calloc_default; -void *(*ts_current_realloc)(void *, size_t) = ts_realloc_default; -void (*ts_current_free)(void *) = free; +TS_PUBLIC void *(*ts_current_malloc)(size_t) = ts_malloc_default; +TS_PUBLIC void *(*ts_current_calloc)(size_t, size_t) = ts_calloc_default; +TS_PUBLIC void *(*ts_current_realloc)(void *, size_t) = ts_realloc_default; +TS_PUBLIC void (*ts_current_free)(void *) = free; void ts_set_allocator( void *(*new_malloc)(size_t size), diff --git a/alloc.h b/alloc.h index c2c63353..a0eadb7a 100644 --- a/alloc.h +++ b/alloc.h @@ -1,20 +1,24 @@ #ifndef TREE_SITTER_ALLOC_H_ #define TREE_SITTER_ALLOC_H_ -#include "api.h" - #ifdef __cplusplus extern "C" { #endif -#include #include #include +#include + +#if defined(TREE_SITTER_HIDDEN_SYMBOLS) || defined(_WIN32) +#define TS_PUBLIC +#else +#define TS_PUBLIC __attribute__((visibility("default"))) +#endif -extern void *(*ts_current_malloc)(size_t); -extern void *(*ts_current_calloc)(size_t, size_t); -extern void *(*ts_current_realloc)(void *, size_t); -extern void (*ts_current_free)(void *); +TS_PUBLIC extern void *(*ts_current_malloc)(size_t); +TS_PUBLIC extern void *(*ts_current_calloc)(size_t, size_t); +TS_PUBLIC extern void *(*ts_current_realloc)(void *, size_t); +TS_PUBLIC extern void (*ts_current_free)(void *); // Allow clients to override allocation functions #ifndef ts_malloc @@ -34,4 +38,4 @@ extern void (*ts_current_free)(void *); } #endif -#endif // TREE_SITTER_ALLOC_H_ +#endif // TREE_SITTER_ALLOC_H_ diff --git a/api.h b/api.h index eeecf317..de122289 100644 --- a/api.h +++ b/api.h @@ -1,9 +1,11 @@ #ifndef TREE_SITTER_API_H_ #define TREE_SITTER_API_H_ +#ifndef TREE_SITTER_HIDE_SYMBOLS #if defined(__GNUC__) || defined(__clang__) #pragma GCC visibility push(default) #endif +#endif #ifdef __cplusplus extern "C" { @@ -46,46 +48,46 @@ typedef struct TSQuery TSQuery; typedef struct TSQueryCursor TSQueryCursor; typedef struct TSLookaheadIterator TSLookaheadIterator; -typedef enum { +typedef enum TSInputEncoding { TSInputEncodingUTF8, TSInputEncodingUTF16, } TSInputEncoding; -typedef enum { +typedef enum TSSymbolType { TSSymbolTypeRegular, TSSymbolTypeAnonymous, TSSymbolTypeAuxiliary, } TSSymbolType; -typedef struct { +typedef struct TSPoint { uint32_t row; uint32_t column; } TSPoint; -typedef struct { +typedef struct TSRange { TSPoint start_point; TSPoint end_point; uint32_t start_byte; uint32_t end_byte; } TSRange; -typedef struct { +typedef struct TSInput { void *payload; const char *(*read)(void *payload, uint32_t byte_index, TSPoint position, uint32_t *bytes_read); TSInputEncoding encoding; } TSInput; -typedef enum { +typedef enum TSLogType { TSLogTypeParse, TSLogTypeLex, } TSLogType; -typedef struct { +typedef struct TSLogger { void *payload; void (*log)(void *payload, TSLogType log_type, const char *buffer); } TSLogger; -typedef struct { +typedef struct TSInputEdit { uint32_t start_byte; uint32_t old_end_byte; uint32_t new_end_byte; @@ -94,24 +96,24 @@ typedef struct { TSPoint new_end_point; } TSInputEdit; -typedef struct { +typedef struct TSNode { uint32_t context[4]; const void *id; const TSTree *tree; } TSNode; -typedef struct { +typedef struct TSTreeCursor { const void *tree; const void *id; - uint32_t context[2]; + uint32_t context[3]; } TSTreeCursor; -typedef struct { +typedef struct TSQueryCapture { TSNode node; uint32_t index; } TSQueryCapture; -typedef enum { +typedef enum TSQuantifier { TSQuantifierZero = 0, // must match the array initialization value TSQuantifierZeroOrOne, TSQuantifierZeroOrMore, @@ -119,25 +121,25 @@ typedef enum { TSQuantifierOneOrMore, } TSQuantifier; -typedef struct { +typedef struct TSQueryMatch { uint32_t id; uint16_t pattern_index; uint16_t capture_count; const TSQueryCapture *captures; } TSQueryMatch; -typedef enum { +typedef enum TSQueryPredicateStepType { TSQueryPredicateStepTypeDone, TSQueryPredicateStepTypeCapture, TSQueryPredicateStepTypeString, } TSQueryPredicateStepType; -typedef struct { +typedef struct TSQueryPredicateStep { TSQueryPredicateStepType type; uint32_t value_id; } TSQueryPredicateStep; -typedef enum { +typedef enum TSQueryError { TSQueryErrorNone = 0, TSQueryErrorSyntax, TSQueryErrorNodeType, @@ -1013,6 +1015,17 @@ void ts_query_cursor_set_max_start_depth(TSQueryCursor *self, uint32_t max_start /* Section - Language */ /**********************/ +/** + * Get another reference to the given language. + */ +const TSLanguage *ts_language_copy(const TSLanguage *self); + +/** + * Free any dynamically-allocated resources for this language, if + * this is the last reference. + */ +void ts_language_delete(const TSLanguage *self); + /** * Get the number of distinct node types in the language. */ @@ -1190,9 +1203,14 @@ const TSLanguage *ts_wasm_store_load_language( TSWasmError *error ); +/** + * Get the number of languages instantiated in the given wasm store. + */ +size_t ts_wasm_store_language_count(const TSWasmStore *); + /** * Check if the language came from a Wasm module. If so, then in order to use - * this langauge with a Parser, that parser must have a Wasm store assigned. + * this language with a Parser, that parser must have a Wasm store assigned. */ bool ts_language_is_wasm(const TSLanguage *); @@ -1239,8 +1257,10 @@ void ts_set_allocator( } #endif +#ifndef TREE_SITTER_HIDE_SYMBOLS #if defined(__GNUC__) || defined(__clang__) #pragma GCC visibility pop #endif +#endif #endif // TREE_SITTER_API_H_ diff --git a/array.h b/array.h index e026f6b2..15a3b233 100644 --- a/array.h +++ b/array.h @@ -5,12 +5,20 @@ extern "C" { #endif -#include -#include -#include +#include "./alloc.h" + #include #include -#include "./alloc.h" +#include +#include +#include + +#ifdef _MSC_VER +#pragma warning(disable : 4101) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#endif #define Array(T) \ struct { \ @@ -19,94 +27,115 @@ extern "C" { uint32_t capacity; \ } +/// Initialize an array. #define array_init(self) \ ((self)->size = 0, (self)->capacity = 0, (self)->contents = NULL) +/// Create an empty array. #define array_new() \ { NULL, 0, 0 } +/// Get a pointer to the element at a given `index` in the array. #define array_get(self, _index) \ (assert((uint32_t)(_index) < (self)->size), &(self)->contents[_index]) +/// Get a pointer to the first element in the array. #define array_front(self) array_get(self, 0) +/// Get a pointer to the last element in the array. #define array_back(self) array_get(self, (self)->size - 1) +/// Clear the array, setting its size to zero. Note that this does not free any +/// memory allocated for the array's contents. #define array_clear(self) ((self)->size = 0) +/// Reserve `new_capacity` elements of space in the array. If `new_capacity` is +/// less than the array's current capacity, this function has no effect. #define array_reserve(self, new_capacity) \ - array__reserve((VoidArray *)(self), array__elem_size(self), new_capacity) + _array__reserve((Array *)(self), array_elem_size(self), new_capacity) -// Free any memory allocated for this array. -#define array_delete(self) array__delete((VoidArray *)(self)) +/// Free any memory allocated for this array. Note that this does not free any +/// memory allocated for the array's contents. +#define array_delete(self) _array__delete((Array *)(self)) +/// Push a new `element` onto the end of the array. #define array_push(self, element) \ - (array__grow((VoidArray *)(self), 1, array__elem_size(self)), \ + (_array__grow((Array *)(self), 1, array_elem_size(self)), \ (self)->contents[(self)->size++] = (element)) -// Increase the array's size by a given number of elements, reallocating -// if necessary. New elements are zero-initialized. +/// Increase the array's size by `count` elements. +/// New elements are zero-initialized. #define array_grow_by(self, count) \ - (array__grow((VoidArray *)(self), count, array__elem_size(self)), \ - memset((self)->contents + (self)->size, 0, (count) * array__elem_size(self)), \ - (self)->size += (count)) + do { \ + if ((count) == 0) break; \ + _array__grow((Array *)(self), count, array_elem_size(self)); \ + memset((self)->contents + (self)->size, 0, (count) * array_elem_size(self)); \ + (self)->size += (count); \ + } while (0) +/// Append all elements from one array to the end of another. #define array_push_all(self, other) \ array_extend((self), (other)->size, (other)->contents) -// Append `count` elements to the end of the array, reading their values from the -// `contents` pointer. +/// Append `count` elements to the end of the array, reading their values from the +/// `contents` pointer. #define array_extend(self, count, contents) \ - array__splice( \ - (VoidArray *)(self), array__elem_size(self), (self)->size, \ + _array__splice( \ + (Array *)(self), array_elem_size(self), (self)->size, \ 0, count, contents \ ) -// Remove `old_count` elements from the array starting at the given `index`. At -// the same index, insert `new_count` new elements, reading their values from the -// `new_contents` pointer. +/// Remove `old_count` elements from the array starting at the given `index`. At +/// the same index, insert `new_count` new elements, reading their values from the +/// `new_contents` pointer. #define array_splice(self, _index, old_count, new_count, new_contents) \ - array__splice( \ - (VoidArray *)(self), array__elem_size(self), _index, \ + _array__splice( \ + (Array *)(self), array_elem_size(self), _index, \ old_count, new_count, new_contents \ ) -// Insert one `element` into the array at the given `index`. +/// Insert one `element` into the array at the given `index`. #define array_insert(self, _index, element) \ - array__splice((VoidArray *)(self), array__elem_size(self), _index, 0, 1, &(element)) + _array__splice((Array *)(self), array_elem_size(self), _index, 0, 1, &(element)) -// Remove one `element` from the array at the given `index`. +/// Remove one element from the array at the given `index`. #define array_erase(self, _index) \ - array__erase((VoidArray *)(self), array__elem_size(self), _index) + _array__erase((Array *)(self), array_elem_size(self), _index) +/// Pop the last element off the array, returning the element by value. #define array_pop(self) ((self)->contents[--(self)->size]) +/// Assign the contents of one array to another, reallocating if necessary. #define array_assign(self, other) \ - array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) + _array__assign((Array *)(self), (const Array *)(other), array_elem_size(self)) +/// Swap one array with another #define array_swap(self, other) \ - array__swap((VoidArray *)(self), (VoidArray *)(other)) - -// Search a sorted array for a given `needle` value, using the given `compare` -// callback to determine the order. -// -// If an existing element is found to be equal to `needle`, then the `index` -// out-parameter is set to the existing value's index, and the `exists` -// out-parameter is set to true. Otherwise, `index` is set to an index where -// `needle` should be inserted in order to preserve the sorting, and `exists` -// is set to false. + _array__swap((Array *)(self), (Array *)(other)) + +/// Get the size of the array contents +#define array_elem_size(self) (sizeof *(self)->contents) + +/// Search a sorted array for a given `needle` value, using the given `compare` +/// callback to determine the order. +/// +/// If an existing element is found to be equal to `needle`, then the `index` +/// out-parameter is set to the existing value's index, and the `exists` +/// out-parameter is set to true. Otherwise, `index` is set to an index where +/// `needle` should be inserted in order to preserve the sorting, and `exists` +/// is set to false. #define array_search_sorted_with(self, compare, needle, _index, _exists) \ - array__search_sorted(self, 0, compare, , needle, _index, _exists) + _array__search_sorted(self, 0, compare, , needle, _index, _exists) -// Search a sorted array for a given `needle` value, using integer comparisons -// of a given struct field (specified with a leading dot) to determine the order. -// -// See also `array_search_sorted_with`. +/// Search a sorted array for a given `needle` value, using integer comparisons +/// of a given struct field (specified with a leading dot) to determine the order. +/// +/// See also `array_search_sorted_with`. #define array_search_sorted_by(self, field, needle, _index, _exists) \ - array__search_sorted(self, 0, compare_int, field, needle, _index, _exists) + _array__search_sorted(self, 0, _compare_int, field, needle, _index, _exists) -// Insert a given `value` into a sorted array, using the given `compare` -// callback to determine the order. +/// Insert a given `value` into a sorted array, using the given `compare` +/// callback to determine the order. #define array_insert_sorted_with(self, compare, value) \ do { \ unsigned _index, _exists; \ @@ -114,10 +143,10 @@ extern "C" { if (!_exists) array_insert(self, _index, value); \ } while (0) -// Insert a given `value` into a sorted array, using integer comparisons of -// a given struct field (specified with a leading dot) to determine the order. -// -// See also `array_search_sorted_by`. +/// Insert a given `value` into a sorted array, using integer comparisons of +/// a given struct field (specified with a leading dot) to determine the order. +/// +/// See also `array_search_sorted_by`. #define array_insert_sorted_by(self, field, value) \ do { \ unsigned _index, _exists; \ @@ -127,11 +156,10 @@ extern "C" { // Private -typedef Array(void) VoidArray; - -#define array__elem_size(self) sizeof(*(self)->contents) +typedef Array(void) Array; -static inline void array__delete(VoidArray *self) { +/// This is not what you're looking for, see `array_delete`. +static inline void _array__delete(Array *self) { if (self->contents) { ts_free(self->contents); self->contents = NULL; @@ -140,7 +168,8 @@ static inline void array__delete(VoidArray *self) { } } -static inline void array__erase(VoidArray *self, size_t element_size, +/// This is not what you're looking for, see `array_erase`. +static inline void _array__erase(Array *self, size_t element_size, uint32_t index) { assert(index < self->size); char *contents = (char *)self->contents; @@ -149,7 +178,8 @@ static inline void array__erase(VoidArray *self, size_t element_size, self->size--; } -static inline void array__reserve(VoidArray *self, size_t element_size, uint32_t new_capacity) { +/// This is not what you're looking for, see `array_reserve`. +static inline void _array__reserve(Array *self, size_t element_size, uint32_t new_capacity) { if (new_capacity > self->capacity) { if (self->contents) { self->contents = ts_realloc(self->contents, new_capacity * element_size); @@ -160,29 +190,33 @@ static inline void array__reserve(VoidArray *self, size_t element_size, uint32_t } } -static inline void array__assign(VoidArray *self, const VoidArray *other, size_t element_size) { - array__reserve(self, element_size, other->size); +/// This is not what you're looking for, see `array_assign`. +static inline void _array__assign(Array *self, const Array *other, size_t element_size) { + _array__reserve(self, element_size, other->size); self->size = other->size; memcpy(self->contents, other->contents, self->size * element_size); } -static inline void array__swap(VoidArray *self, VoidArray *other) { - VoidArray swap = *other; +/// This is not what you're looking for, see `array_swap`. +static inline void _array__swap(Array *self, Array *other) { + Array swap = *other; *other = *self; *self = swap; } -static inline void array__grow(VoidArray *self, uint32_t count, size_t element_size) { +/// This is not what you're looking for, see `array_push` or `array_grow_by`. +static inline void _array__grow(Array *self, uint32_t count, size_t element_size) { uint32_t new_size = self->size + count; if (new_size > self->capacity) { uint32_t new_capacity = self->capacity * 2; if (new_capacity < 8) new_capacity = 8; if (new_capacity < new_size) new_capacity = new_size; - array__reserve(self, element_size, new_capacity); + _array__reserve(self, element_size, new_capacity); } } -static inline void array__splice(VoidArray *self, size_t element_size, +/// This is not what you're looking for, see `array_splice`. +static inline void _array__splice(Array *self, size_t element_size, uint32_t index, uint32_t old_count, uint32_t new_count, const void *elements) { uint32_t new_size = self->size + new_count - old_count; @@ -190,7 +224,7 @@ static inline void array__splice(VoidArray *self, size_t element_size, uint32_t new_end = index + new_count; assert(old_end <= self->size); - array__reserve(self, element_size, new_size); + _array__reserve(self, element_size, new_size); char *contents = (char *)self->contents; if (self->size > old_end) { @@ -218,8 +252,9 @@ static inline void array__splice(VoidArray *self, size_t element_size, self->size += new_count - old_count; } -// A binary search routine, based on Rust's `std::slice::binary_search_by`. -#define array__search_sorted(self, start, compare, suffix, needle, _index, _exists) \ +/// A binary search routine, based on Rust's `std::slice::binary_search_by`. +/// This is not what you're looking for, see `array_search_sorted_with` or `array_search_sorted_by`. +#define _array__search_sorted(self, start, compare, suffix, needle, _index, _exists) \ do { \ *(_index) = start; \ *(_exists) = false; \ @@ -238,9 +273,15 @@ static inline void array__splice(VoidArray *self, size_t element_size, else if (comparison < 0) *(_index) += 1; \ } while (0) -// Helper macro for the `_sorted_by` routines below. This takes the left (existing) -// parameter by reference in order to work with the generic sorting function above. -#define compare_int(a, b) ((int)*(a) - (int)(b)) +/// Helper macro for the `_sorted_by` routines below. This takes the left (existing) +/// parameter by reference in order to work with the generic sorting function above. +#define _compare_int(a, b) ((int)*(a) - (int)(b)) + +#ifdef _MSC_VER +#pragma warning(default : 4101) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif #ifdef __cplusplus } diff --git a/language.c b/language.c index 186edfa8..84b15c01 100644 --- a/language.c +++ b/language.c @@ -1,6 +1,21 @@ #include "./language.h" +#include "./wasm_store.h" +#include "api.h" #include +const TSLanguage *ts_language_copy(const TSLanguage *self) { + if (self && ts_language_is_wasm(self)) { + ts_wasm_language_retain(self); + } + return self; +} + +void ts_language_delete(const TSLanguage *self) { + if (self && ts_language_is_wasm(self)) { + ts_wasm_language_release(self); + } +} + uint32_t ts_language_symbol_count(const TSLanguage *self) { return self->symbol_count + self->alias_count; } diff --git a/language.h b/language.h index 4de9a918..4e2769b4 100644 --- a/language.h +++ b/language.h @@ -6,10 +6,13 @@ extern "C" { #endif #include "./subtree.h" -#include "parser.h" +#include "./parser.h" #define ts_builtin_sym_error_repeat (ts_builtin_sym_error - 1) +#define LANGUAGE_VERSION_WITH_PRIMARY_STATES 14 +#define LANGUAGE_VERSION_USABLE_VIA_WASM 13 + typedef struct { const TSParseAction *actions; uint32_t action_count; @@ -186,7 +189,7 @@ static inline bool ts_language_state_is_primary( const TSLanguage *self, TSStateId state ) { - if (self->version >= 14) { + if (self->version >= LANGUAGE_VERSION_WITH_PRIMARY_STATES) { return state == self->primary_state_ids[state]; } else { return true; diff --git a/lexer.h b/lexer.h index f79f6aa6..a8cc38f1 100644 --- a/lexer.h +++ b/lexer.h @@ -8,7 +8,7 @@ extern "C" { #include "./length.h" #include "./subtree.h" #include "api.h" -#include "parser.h" +#include "./parser.h" typedef struct { TSLexer data; diff --git a/node.c b/node.c index 546b9099..f9960213 100644 --- a/node.c +++ b/node.c @@ -439,7 +439,14 @@ const char *ts_node_grammar_type(TSNode self) { } char *ts_node_string(TSNode self) { - return ts_subtree_string(ts_node__subtree(self), self.tree->language, false); + TSSymbol alias_symbol = ts_node__alias(&self); + return ts_subtree_string( + ts_node__subtree(self), + alias_symbol, + ts_language_symbol_metadata(self.tree->language, alias_symbol).visible, + self.tree->language, + false + ); } bool ts_node_eq(TSNode self, TSNode other) { @@ -513,7 +520,7 @@ TSNode ts_node_parent(TSNode self) { ts_node_start_byte(child) > ts_node_start_byte(self) || child.id == self.id ) break; - if (iterator.position.bytes >= end_byte) { + if (iterator.position.bytes >= end_byte && ts_node_child_count(child) > 0) { node = child; if (ts_node__is_relevant(child, true)) { last_visible_node = node; diff --git a/parser.c b/parser.c index 8058bd60..a72983dd 100644 --- a/parser.c +++ b/parser.c @@ -1,8 +1,11 @@ +#define _POSIX_C_SOURCE 200112L + #include #include #include #include #include +#include #include "api.h" #include "./alloc.h" #include "./array.h" @@ -18,7 +21,7 @@ #include "./stack.h" #include "./subtree.h" #include "./tree.h" -#include "./wasm.h" +#include "./wasm_store.h" #define LOG(...) \ if (self->lexer.logger.log || self->dot_graph_file) { \ @@ -107,6 +110,7 @@ struct TSParser { Subtree old_tree; TSRangeArray included_range_differences; unsigned included_range_difference_index; + bool has_scanner_error; }; typedef struct { @@ -334,6 +338,22 @@ static bool ts_parser__better_version_exists( return false; } +static bool ts_parser__call_main_lex_fn(TSParser *self, TSLexMode lex_mode) { + if (ts_language_is_wasm(self->language)) { + return ts_wasm_store_call_lex_main(self->wasm_store, lex_mode.lex_state); + } else { + return self->language->lex_fn(&self->lexer.data, lex_mode.lex_state); + } +} + +static bool ts_parser__call_keyword_lex_fn(TSParser *self, TSLexMode lex_mode) { + if (ts_language_is_wasm(self->language)) { + return ts_wasm_store_call_lex_keyword(self->wasm_store, 0); + } else { + return self->language->keyword_lex_fn(&self->lexer.data, 0); + } +} + static void ts_parser__external_scanner_create( TSParser *self ) { @@ -342,6 +362,9 @@ static void ts_parser__external_scanner_create( self->external_scanner_payload = (void *)(uintptr_t)ts_wasm_store_call_scanner_create( self->wasm_store ); + if (ts_wasm_store_has_error(self->wasm_store)) { + self->has_scanner_error = true; + } } else if (self->language->external_scanner.create) { self->external_scanner_payload = self->language->external_scanner.create(); } @@ -351,21 +374,17 @@ static void ts_parser__external_scanner_create( static void ts_parser__external_scanner_destroy( TSParser *self ) { - if (self->language && self->external_scanner_payload) { - if (ts_language_is_wasm(self->language)) { - if (self->wasm_store) { - ts_wasm_store_call_scanner_destroy( - self->wasm_store, - (uintptr_t)self->external_scanner_payload - ); - } - } else if (self->language->external_scanner.destroy) { - self->language->external_scanner.destroy( - self->external_scanner_payload - ); - } - self->external_scanner_payload = NULL; + if ( + self->language && + self->external_scanner_payload && + self->language->external_scanner.destroy && + !ts_language_is_wasm(self->language) + ) { + self->language->external_scanner.destroy( + self->external_scanner_payload + ); } + self->external_scanner_payload = NULL; } static unsigned ts_parser__external_scanner_serialize( @@ -403,6 +422,9 @@ static void ts_parser__external_scanner_deserialize( data, length ); + if (ts_wasm_store_has_error(self->wasm_store)) { + self->has_scanner_error = true; + } } else { self->language->external_scanner.deserialize( self->external_scanner_payload, @@ -416,13 +438,16 @@ static bool ts_parser__external_scanner_scan( TSParser *self, TSStateId external_lex_state ) { - if (ts_language_is_wasm(self->language)) { - return ts_wasm_store_call_scanner_scan( + bool result = ts_wasm_store_call_scanner_scan( self->wasm_store, (uintptr_t)self->external_scanner_payload, external_lex_state * self->language->external_token_count ); + if (ts_wasm_store_has_error(self->wasm_store)) { + self->has_scanner_error = true; + } + return result; } else { const bool *valid_external_tokens = ts_language_enabled_external_tokens( self->language, @@ -511,6 +536,7 @@ static Subtree ts_parser__lex( ts_lexer_start(&self->lexer); ts_parser__external_scanner_deserialize(self, external_token); found_token = ts_parser__external_scanner_scan(self, lex_mode.external_lex_state); + if (self->has_scanner_error) return NULL_SUBTREE; ts_lexer_finish(&self->lexer, &lookahead_end_byte); if (found_token) { @@ -561,12 +587,7 @@ static Subtree ts_parser__lex( current_position.extent.column ); ts_lexer_start(&self->lexer); - found_token = false; - if (ts_language_is_wasm(self->language)) { - found_token = ts_wasm_store_call_lex_main(self->wasm_store, lex_mode.lex_state); - } else { - found_token = self->language->lex_fn(&self->lexer.data, lex_mode.lex_state); - } + found_token = ts_parser__call_main_lex_fn(self, lex_mode); ts_lexer_finish(&self->lexer, &lookahead_end_byte); if (found_token) break; @@ -624,11 +645,7 @@ static Subtree ts_parser__lex( ts_lexer_reset(&self->lexer, self->lexer.token_start_position); ts_lexer_start(&self->lexer); - if (ts_language_is_wasm(self->language)) { - is_keyword = ts_wasm_store_call_lex_keyword(self->wasm_store, 0); - } else { - is_keyword = self->language->keyword_lex_fn(&self->lexer.data, 0); - } + is_keyword = ts_parser__call_keyword_lex_fn(self, lex_mode); if ( is_keyword && @@ -818,14 +835,14 @@ static bool ts_parser__select_tree(TSParser *self, Subtree left, Subtree right) } if (ts_subtree_dynamic_precedence(right) > ts_subtree_dynamic_precedence(left)) { - LOG("select_higher_precedence symbol:%s, prec:%u, over_symbol:%s, other_prec:%u", + LOG("select_higher_precedence symbol:%s, prec:%" PRId32 ", over_symbol:%s, other_prec:%" PRId32, TREE_NAME(right), ts_subtree_dynamic_precedence(right), TREE_NAME(left), ts_subtree_dynamic_precedence(left)); return true; } if (ts_subtree_dynamic_precedence(left) > ts_subtree_dynamic_precedence(right)) { - LOG("select_higher_precedence symbol:%s, prec:%u, over_symbol:%s, other_prec:%u", + LOG("select_higher_precedence symbol:%s, prec:%" PRId32 ", over_symbol:%s, other_prec:%" PRId32, TREE_NAME(left), ts_subtree_dynamic_precedence(left), TREE_NAME(right), ts_subtree_dynamic_precedence(right)); return false; @@ -1478,7 +1495,7 @@ static void ts_parser__handle_error( ts_stack_record_summary(self->stack, version, MAX_SUMMARY_DEPTH); // Begin recovery with the current lookahead node, rather than waiting for the - // next turn of the parse loop. This ensures that the tree accounts for the the + // next turn of the parse loop. This ensures that the tree accounts for the // current lookahead token's "lookahead bytes" value, which describes how far // the lexer needed to look ahead beyond the content of the token in order to // recognize it. @@ -1525,6 +1542,7 @@ static bool ts_parser__advance( if (needs_lex) { needs_lex = false; lookahead = ts_parser__lex(self, version, state); + if (self->has_scanner_error) return false; if (lookahead.ptr) { ts_parser__set_cached_token(self, position, last_external_token, lookahead); @@ -1768,7 +1786,7 @@ static unsigned ts_parser__condense_stack(TSParser *self) { } } - // Enfore a hard upper bound on the number of stack versions by + // Enforce a hard upper bound on the number of stack versions by // discarding the least promising versions. while (ts_stack_version_count(self->stack) > MAX_VERSION_COUNT) { ts_stack_remove_version(self->stack, MAX_VERSION_COUNT); @@ -1809,6 +1827,7 @@ static unsigned ts_parser__condense_stack(TSParser *self) { static bool ts_parser_has_outstanding_parse(TSParser *self) { return ( + self->external_scanner_payload || ts_stack_state(self->stack, 0) != 1 || ts_stack_node_count_since_error(self->stack, 0) != 0 ); @@ -1828,6 +1847,9 @@ TSParser *ts_parser_new(void) { self->dot_graph_file = NULL; self->cancellation_flag = NULL; self->timeout_duration = 0; + self->language = NULL; + self->has_scanner_error = false; + self->external_scanner_payload = NULL; self->end_clock = clock_null(); self->operation_count = 0; self->old_tree = NULL_SUBTREE; @@ -1868,7 +1890,8 @@ const TSLanguage *ts_parser_language(const TSParser *self) { } bool ts_parser_set_language(TSParser *self, const TSLanguage *language) { - ts_parser__external_scanner_destroy(self); + ts_parser_reset(self); + ts_language_delete(self->language); self->language = NULL; if (language) { @@ -1885,9 +1908,7 @@ bool ts_parser_set_language(TSParser *self, const TSLanguage *language) { } } - self->language = language; - ts_parser__external_scanner_create(self); - ts_parser_reset(self); + self->language = ts_language_copy(language); return true; } @@ -1944,8 +1965,9 @@ const TSRange *ts_parser_included_ranges(const TSParser *self, uint32_t *count) } void ts_parser_reset(TSParser *self) { - if (self->language && self->language->external_scanner.deserialize) { - self->language->external_scanner.deserialize(self->external_scanner_payload, NULL, 0); + ts_parser__external_scanner_destroy(self); + if (self->wasm_store) { + ts_wasm_store_reset(self->wasm_store); } if (self->old_tree.ptr) { @@ -1962,6 +1984,7 @@ void ts_parser_reset(TSParser *self) { self->finished_tree = NULL_SUBTREE; } self->accept_count = 0; + self->has_scanner_error = false; } TSTree *ts_parser_parse( @@ -1969,41 +1992,43 @@ TSTree *ts_parser_parse( const TSTree *old_tree, TSInput input ) { + TSTree *result = NULL; if (!self->language || !input.read) return NULL; if (ts_language_is_wasm(self->language)) { - if (self->wasm_store) { - ts_wasm_store_start(self->wasm_store, &self->lexer.data, self->language); - } else { - return NULL; - } + if (!self->wasm_store) return NULL; + ts_wasm_store_start(self->wasm_store, &self->lexer.data, self->language); } ts_lexer_set_input(&self->lexer, input); - array_clear(&self->included_range_differences); self->included_range_difference_index = 0; if (ts_parser_has_outstanding_parse(self)) { LOG("resume_parsing"); - } else if (old_tree) { - ts_subtree_retain(old_tree->root); - self->old_tree = old_tree->root; - ts_range_array_get_changed_ranges( - old_tree->included_ranges, old_tree->included_range_count, - self->lexer.included_ranges, self->lexer.included_range_count, - &self->included_range_differences - ); - reusable_node_reset(&self->reusable_node, old_tree->root); - LOG("parse_after_edit"); - LOG_TREE(self->old_tree); - for (unsigned i = 0; i < self->included_range_differences.size; i++) { - TSRange *range = &self->included_range_differences.contents[i]; - LOG("different_included_range %u - %u", range->start_byte, range->end_byte); - } } else { - reusable_node_clear(&self->reusable_node); - LOG("new_parse"); + ts_parser__external_scanner_create(self); + if (self->has_scanner_error) goto exit; + + if (old_tree) { + ts_subtree_retain(old_tree->root); + self->old_tree = old_tree->root; + ts_range_array_get_changed_ranges( + old_tree->included_ranges, old_tree->included_range_count, + self->lexer.included_ranges, self->lexer.included_range_count, + &self->included_range_differences + ); + reusable_node_reset(&self->reusable_node, old_tree->root); + LOG("parse_after_edit"); + LOG_TREE(self->old_tree); + for (unsigned i = 0; i < self->included_range_differences.size; i++) { + TSRange *range = &self->included_range_differences.contents[i]; + LOG("different_included_range %u - %u", range->start_byte, range->end_byte); + } + } else { + reusable_node_clear(&self->reusable_node); + LOG("new_parse"); + } } self->operation_count = 0; @@ -2024,7 +2049,7 @@ TSTree *ts_parser_parse( bool allow_node_reuse = version_count == 1; while (ts_stack_is_active(self->stack, version)) { LOG( - "process version:%d, version_count:%u, state:%d, row:%u, col:%u", + "process version:%u, version_count:%u, state:%d, row:%u, col:%u", version, ts_stack_version_count(self->stack), ts_stack_state(self->stack, version), @@ -2032,7 +2057,11 @@ TSTree *ts_parser_parse( ts_stack_position(self->stack, version).extent.column ); - if (!ts_parser__advance(self, version, allow_node_reuse)) return NULL; + if (!ts_parser__advance(self, version, allow_node_reuse)) { + if (self->has_scanner_error) goto exit; + return NULL; + } + LOG_STACK(); position = ts_stack_position(self->stack, version).bytes; @@ -2071,13 +2100,15 @@ TSTree *ts_parser_parse( LOG("done"); LOG_TREE(self->finished_tree); - TSTree *result = ts_tree_new( + result = ts_tree_new( self->finished_tree, self->language, self->lexer.included_ranges, self->lexer.included_range_count ); self->finished_tree = NULL_SUBTREE; + +exit: ts_parser_reset(self); return result; } diff --git a/parser.h b/parser.h index 17b4fde9..17f0e94b 100644 --- a/parser.h +++ b/parser.h @@ -86,6 +86,11 @@ typedef union { } entry; } TSParseActionEntry; +typedef struct { + int32_t start; + int32_t end; +} TSCharacterRange; + struct TSLanguage { uint32_t version; uint32_t symbol_count; @@ -125,6 +130,24 @@ struct TSLanguage { const TSStateId *primary_state_ids; }; +static inline bool set_contains(TSCharacterRange *ranges, uint32_t len, int32_t lookahead) { + uint32_t index = 0; + uint32_t size = len - index; + while (size > 1) { + uint32_t half_size = size / 2; + uint32_t mid_index = index + half_size; + TSCharacterRange *range = &ranges[mid_index]; + if (lookahead >= range->start && lookahead <= range->end) { + return true; + } else if (lookahead > range->end) { + index = mid_index; + } + size -= half_size; + } + TSCharacterRange *range = &ranges[index]; + return (lookahead >= range->start && lookahead <= range->end); +} + /* * Lexer Macros */ @@ -154,6 +177,17 @@ struct TSLanguage { goto next_state; \ } +#define ADVANCE_MAP(...) \ + { \ + static const uint16_t map[] = { __VA_ARGS__ }; \ + for (uint32_t i = 0; i < sizeof(map) / sizeof(map[0]); i += 2) { \ + if (map[i] == lookahead) { \ + state = map[i + 1]; \ + goto next_state; \ + } \ + } \ + } + #define SKIP(state_value) \ { \ skip = true; \ @@ -203,14 +237,15 @@ struct TSLanguage { } \ }} -#define REDUCE(symbol_val, child_count_val, ...) \ - {{ \ - .reduce = { \ - .type = TSParseActionTypeReduce, \ - .symbol = symbol_val, \ - .child_count = child_count_val, \ - __VA_ARGS__ \ - }, \ +#define REDUCE(symbol_name, children, precedence, prod_id) \ + {{ \ + .reduce = { \ + .type = TSParseActionTypeReduce, \ + .symbol = symbol_name, \ + .child_count = children, \ + .dynamic_precedence = precedence, \ + .production_id = prod_id \ + }, \ }} #define RECOVER() \ diff --git a/query.c b/query.c index 34279db3..1b6e04b6 100644 --- a/query.c +++ b/query.c @@ -42,7 +42,7 @@ typedef struct { * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. * - `negated_field_list_id` - An id representing a set of fields that must - * that must not be present on a node matching this step. + * not be present on a node matching this step. * * Steps have some additional fields in order to handle the `.` (or "anchor") operator, * which forbids additional child nodes: @@ -1030,7 +1030,7 @@ static inline void analysis_state_set__delete(AnalysisStateSet *self) { * QueryAnalyzer ****************/ -static inline QueryAnalysis query_analysis__new() { +static inline QueryAnalysis query_analysis__new(void) { return (QueryAnalysis) { .states = array_new(), .next_states = array_new(), @@ -2312,15 +2312,8 @@ static TSQueryError ts_query__parse_pattern( stream_scan_identifier(stream); uint32_t length = (uint32_t)(stream->input - node_name); - // TODO - remove. - // For temporary backward compatibility, handle predicates without the leading '#' sign. - if (length > 0 && (node_name[length - 1] == '!' || node_name[length - 1] == '?')) { - stream_reset(stream, node_name); - return ts_query__parse_predicate(self, stream); - } - // Parse the wildcard symbol - else if (length == 1 && node_name[0] == '_') { + if (length == 1 && node_name[0] == '_') { symbol = WILDCARD_SYMBOL; } @@ -2650,7 +2643,6 @@ static TSQueryError ts_query__parse_pattern( step->alternative_index < self->steps.size ) { step_index = step->alternative_index; - step = &self->steps.contents[step_index]; } else { break; } @@ -2698,7 +2690,7 @@ TSQuery *ts_query_new( .negated_fields = array_new(), .repeat_symbols_with_rootless_patterns = array_new(), .wildcard_root_pattern_count = 0, - .language = language, + .language = ts_language_copy(language), }; array_push(&self->negated_fields, 0); @@ -2812,6 +2804,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->string_buffer); array_delete(&self->negated_fields); array_delete(&self->repeat_symbols_with_rootless_patterns); + ts_language_delete(self->language); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); for (uint32_t index = 0; index < self->capture_quantifiers.size; index++) { @@ -3848,7 +3841,7 @@ static inline bool ts_query_cursor__advance( continue; } - // Enfore the longest-match criteria. When a query pattern contains optional or + // Enforce the longest-match criteria. When a query pattern contains optional or // repeated nodes, this is necessary to avoid multiple redundant states, where // one state has a strict subset of another state's captures. bool did_remove = false; diff --git a/stack.c b/stack.c index 34846352..98d8c561 100644 --- a/stack.c +++ b/stack.c @@ -5,6 +5,7 @@ #include "./stack.h" #include "./length.h" #include +#include #include #define MAX_LINK_COUNT 8 @@ -12,9 +13,9 @@ #define MAX_ITERATOR_COUNT 64 #if defined _WIN32 && !defined __GNUC__ -#define inline __forceinline +#define forceinline __forceinline #else -#define inline static inline __attribute__((always_inline)) +#define forceinline static inline __attribute__((always_inline)) #endif typedef struct StackNode StackNode; @@ -227,7 +228,8 @@ static void stack_node_add_link( // If the previous nodes are mergeable, merge them recursively. if ( existing_link->node->state == link.node->state && - existing_link->node->position.bytes == link.node->position.bytes + existing_link->node->position.bytes == link.node->position.bytes && + existing_link->node->error_cost == link.node->error_cost ) { for (int j = 0; j < link.node->link_count; j++) { stack_node_add_link(existing_link->node, link.node->links[j], subtree_pool); @@ -509,7 +511,7 @@ void ts_stack_push( head->node = new_node; } -inline StackAction pop_count_callback(void *payload, const StackIterator *iterator) { +forceinline StackAction pop_count_callback(void *payload, const StackIterator *iterator) { unsigned *goal_subtree_count = payload; if (iterator->subtree_count == *goal_subtree_count) { return StackActionPop | StackActionStop; @@ -522,7 +524,7 @@ StackSliceArray ts_stack_pop_count(Stack *self, StackVersion version, uint32_t c return stack__iter(self, version, pop_count_callback, &count, (int)count); } -inline StackAction pop_pending_callback(void *payload, const StackIterator *iterator) { +forceinline StackAction pop_pending_callback(void *payload, const StackIterator *iterator) { (void)payload; if (iterator->subtree_count >= 1) { if (iterator->is_pending) { @@ -544,7 +546,7 @@ StackSliceArray ts_stack_pop_pending(Stack *self, StackVersion version) { return pop; } -inline StackAction pop_error_callback(void *payload, const StackIterator *iterator) { +forceinline StackAction pop_error_callback(void *payload, const StackIterator *iterator) { if (iterator->subtrees.size > 0) { bool *found_error = payload; if (!*found_error && ts_subtree_is_error(iterator->subtrees.contents[0])) { @@ -575,7 +577,7 @@ SubtreeArray ts_stack_pop_error(Stack *self, StackVersion version) { return (SubtreeArray) {.size = 0}; } -inline StackAction pop_all_callback(void *payload, const StackIterator *iterator) { +forceinline StackAction pop_all_callback(void *payload, const StackIterator *iterator) { (void)payload; return iterator->node->link_count == 0 ? StackActionPop : StackActionNone; } @@ -589,7 +591,7 @@ typedef struct { unsigned max_depth; } SummarizeStackSession; -inline StackAction summarize_stack_callback(void *payload, const StackIterator *iterator) { +forceinline StackAction summarize_stack_callback(void *payload, const StackIterator *iterator) { SummarizeStackSession *session = payload; TSStateId state = iterator->node->state; unsigned depth = iterator->subtree_count; @@ -866,7 +868,7 @@ bool ts_stack_print_dot_graph(Stack *self, const TSLanguage *language, FILE *f) fprintf(f, "\""); fprintf( f, - "labeltooltip=\"error_cost: %u\ndynamic_precedence: %u\"", + "labeltooltip=\"error_cost: %u\ndynamic_precedence: %" PRId32 "\"", ts_subtree_error_cost(link.subtree), ts_subtree_dynamic_precedence(link.subtree) ); @@ -894,4 +896,4 @@ bool ts_stack_print_dot_graph(Stack *self, const TSLanguage *language, FILE *f) return true; } -#undef inline +#undef forceinline diff --git a/subtree.c b/subtree.c index cad48df4..4524e182 100644 --- a/subtree.c +++ b/subtree.c @@ -629,9 +629,9 @@ int ts_subtree_compare(Subtree left, Subtree right, SubtreePool *pool) { int result = 0; if (ts_subtree_symbol(left) < ts_subtree_symbol(right)) result = -1; - if (ts_subtree_symbol(right) < ts_subtree_symbol(left)) result = 1; - if (ts_subtree_child_count(left) < ts_subtree_child_count(right)) result = -1; - if (ts_subtree_child_count(right) < ts_subtree_child_count(left)) result = 1; + else if (ts_subtree_symbol(right) < ts_subtree_symbol(left)) result = 1; + else if (ts_subtree_child_count(left) < ts_subtree_child_count(right)) result = -1; + else if (ts_subtree_child_count(right) < ts_subtree_child_count(left)) result = 1; if (result != 0) { array_clear(&pool->tree_stack); return result; @@ -890,9 +890,15 @@ static size_t ts_subtree__write_to_string( } } } else if (is_root) { - TSSymbol symbol = ts_subtree_symbol(self); + TSSymbol symbol = alias_symbol ? alias_symbol : ts_subtree_symbol(self); const char *symbol_name = ts_language_symbol_name(language, symbol); - cursor += snprintf(*writer, limit, "(\"%s\")", symbol_name); + if (ts_subtree_child_count(self) > 0) { + cursor += snprintf(*writer, limit, "(%s", symbol_name); + } else if (ts_subtree_named(self)) { + cursor += snprintf(*writer, limit, "(%s)", symbol_name); + } else { + cursor += snprintf(*writer, limit, "(\"%s\")", symbol_name); + } } if (ts_subtree_child_count(self)) { @@ -947,6 +953,8 @@ static size_t ts_subtree__write_to_string( char *ts_subtree_string( Subtree self, + TSSymbol alias_symbol, + bool alias_is_named, const TSLanguage *language, bool include_all ) { @@ -954,13 +962,13 @@ char *ts_subtree_string( size_t size = ts_subtree__write_to_string( self, scratch_string, 1, language, include_all, - 0, false, ROOT_FIELD + alias_symbol, alias_is_named, ROOT_FIELD ) + 1; char *result = ts_malloc(size * sizeof(char)); ts_subtree__write_to_string( self, result, size, language, include_all, - 0, false, ROOT_FIELD + alias_symbol, alias_is_named, ROOT_FIELD ); return result; } @@ -997,7 +1005,7 @@ void ts_subtree__print_dot_graph(const Subtree *self, uint32_t start_offset, ts_subtree_lookahead_bytes(*self) ); - if (ts_subtree_is_error(*self) && ts_subtree_child_count(*self) == 0) { + if (ts_subtree_is_error(*self) && ts_subtree_child_count(*self) == 0 && self->ptr->lookahead_char != 0) { fprintf(f, "\ncharacter: '%c'", self->ptr->lookahead_char); } diff --git a/subtree.h b/subtree.h index 13eaf4de..0b3062e9 100644 --- a/subtree.h +++ b/subtree.h @@ -13,7 +13,7 @@ extern "C" { #include "./error_costs.h" #include "./host.h" #include "api.h" -#include "parser.h" +#include "./parser.h" #define TS_TREE_STATE_NONE USHRT_MAX #define NULL_SUBTREE ((Subtree) {.ptr = NULL}) @@ -206,7 +206,7 @@ void ts_subtree_summarize(MutableSubtree, const Subtree *, uint32_t, const TSLan void ts_subtree_summarize_children(MutableSubtree, const TSLanguage *); void ts_subtree_balance(Subtree, SubtreePool *, const TSLanguage *); Subtree ts_subtree_edit(Subtree, const TSInputEdit *edit, SubtreePool *); -char *ts_subtree_string(Subtree, const TSLanguage *, bool include_all); +char *ts_subtree_string(Subtree, TSSymbol, bool, const TSLanguage *, bool include_all); void ts_subtree_print_dot_graph(Subtree, const TSLanguage *, FILE *); Subtree ts_subtree_last_external_token(Subtree); const ExternalScannerState *ts_subtree_external_scanner_state(Subtree self); diff --git a/tree.c b/tree.c index 90e87789..1cea1794 100644 --- a/tree.c +++ b/tree.c @@ -1,3 +1,5 @@ +#define _POSIX_C_SOURCE 200112L + #include "api.h" #include "./array.h" #include "./get_changed_ranges.h" @@ -12,7 +14,7 @@ TSTree *ts_tree_new( ) { TSTree *result = ts_malloc(sizeof(TSTree)); result->root = root; - result->language = language; + result->language = ts_language_copy(language); result->included_ranges = ts_calloc(included_range_count, sizeof(TSRange)); memcpy(result->included_ranges, included_ranges, included_range_count * sizeof(TSRange)); result->included_range_count = included_range_count; @@ -30,6 +32,7 @@ void ts_tree_delete(TSTree *self) { SubtreePool pool = ts_subtree_pool_new(0); ts_subtree_release(&pool, self->root); ts_subtree_pool_delete(&pool); + ts_language_delete(self->language); ts_free(self->included_ranges); ts_free(self); } @@ -99,8 +102,8 @@ TSRange *ts_tree_included_ranges(const TSTree *self, uint32_t *length) { } TSRange *ts_tree_get_changed_ranges(const TSTree *old_tree, const TSTree *new_tree, uint32_t *length) { - TreeCursor cursor1 = {NULL, array_new()}; - TreeCursor cursor2 = {NULL, array_new()}; + TreeCursor cursor1 = {NULL, array_new(), 0}; + TreeCursor cursor2 = {NULL, array_new(), 0}; ts_tree_cursor_init(&cursor1, ts_tree_root_node(old_tree)); ts_tree_cursor_init(&cursor2, ts_tree_root_node(new_tree)); @@ -125,17 +128,36 @@ TSRange *ts_tree_get_changed_ranges(const TSTree *old_tree, const TSTree *new_tr #ifdef _WIN32 +#include +#include + +int _ts_dup(HANDLE handle) { + HANDLE dup_handle; + if (!DuplicateHandle( + GetCurrentProcess(), handle, + GetCurrentProcess(), &dup_handle, + 0, FALSE, DUPLICATE_SAME_ACCESS + )) return -1; + + return _open_osfhandle((intptr_t)dup_handle, 0); +} + void ts_tree_print_dot_graph(const TSTree *self, int fd) { - (void)self; - (void)fd; + FILE *file = _fdopen(_ts_dup((HANDLE)_get_osfhandle(fd)), "a"); + ts_subtree_print_dot_graph(self->root, self->language, file); + fclose(file); } #else #include +int _ts_dup(int file_descriptor) { + return dup(file_descriptor); +} + void ts_tree_print_dot_graph(const TSTree *self, int file_descriptor) { - FILE *file = fdopen(dup(file_descriptor), "a"); + FILE *file = fdopen(_ts_dup(file_descriptor), "a"); ts_subtree_print_dot_graph(self->root, self->language, file); fclose(file); } diff --git a/tree_cursor.c b/tree_cursor.c index cfe0e1cb..c1a3d8a4 100644 --- a/tree_cursor.c +++ b/tree_cursor.c @@ -151,7 +151,7 @@ static inline bool ts_tree_cursor_child_iterator_previous( // TSTreeCursor - lifecycle TSTreeCursor ts_tree_cursor_new(TSNode node) { - TSTreeCursor self = {NULL, NULL, {0, 0}}; + TSTreeCursor self = {NULL, NULL, {0, 0, 0}}; ts_tree_cursor_init((TreeCursor *)&self, node); return self; } @@ -162,6 +162,7 @@ void ts_tree_cursor_reset(TSTreeCursor *_self, TSNode node) { void ts_tree_cursor_init(TreeCursor *self, TSNode node) { self->tree = node.tree; + self->root_alias_symbol = node.context[3]; array_clear(&self->stack); array_push(&self->stack, ((TreeCursorEntry) { .subtree = (const Subtree *)node.id, @@ -221,7 +222,7 @@ TreeCursorStep ts_tree_cursor_goto_last_child_internal(TSTreeCursor *_self) { CursorChildIterator iterator = ts_tree_cursor_iterate_children(self); if (!iterator.parent.ptr || iterator.parent.ptr->child_count == 0) return TreeCursorStepNone; - TreeCursorEntry last_entry; + TreeCursorEntry last_entry = {0}; TreeCursorStep last_step = TreeCursorStepNone; while (ts_tree_cursor_child_iterator_next(&iterator, &entry, &visible)) { if (visible) { @@ -362,7 +363,6 @@ TreeCursorStep ts_tree_cursor_goto_previous_sibling_internal(TSTreeCursor *_self TreeCursor *self = (TreeCursor *)_self; // for that, save current position before traversing - Length position = array_back(&self->stack)->position; TreeCursorStep step = ts_tree_cursor_goto_sibling_internal( _self, ts_tree_cursor_child_iterator_previous); if (step == TreeCursorStepNone) @@ -374,7 +374,7 @@ TreeCursorStep ts_tree_cursor_goto_previous_sibling_internal(TSTreeCursor *_self // restore position from the parent node const TreeCursorEntry *parent = &self->stack.contents[self->stack.size - 2]; - position = parent->position; + Length position = parent->position; uint32_t child_index = array_back(&self->stack)->child_index; const Subtree *children = ts_subtree_children((*(parent->subtree))); @@ -475,7 +475,7 @@ uint32_t ts_tree_cursor_current_descendant_index(const TSTreeCursor *_self) { TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) { const TreeCursor *self = (const TreeCursor *)_self; TreeCursorEntry *last_entry = array_back(&self->stack); - TSSymbol alias_symbol = 0; + TSSymbol alias_symbol = self->root_alias_symbol; if (self->stack.size > 1 && !ts_subtree_extra(*last_entry->subtree)) { TreeCursorEntry *parent_entry = &self->stack.contents[self->stack.size - 2]; alias_symbol = ts_language_alias_at( @@ -698,6 +698,7 @@ TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *_cursor) { TSTreeCursor res = {NULL, NULL, {0, 0}}; TreeCursor *copy = (TreeCursor *)&res; copy->tree = cursor->tree; + copy->root_alias_symbol = cursor->root_alias_symbol; array_init(©->stack); array_push_all(©->stack, &cursor->stack); return res; @@ -707,6 +708,7 @@ void ts_tree_cursor_reset_to(TSTreeCursor *_dst, const TSTreeCursor *_src) { const TreeCursor *cursor = (const TreeCursor *)_src; TreeCursor *copy = (TreeCursor *)_dst; copy->tree = cursor->tree; + copy->root_alias_symbol = cursor->root_alias_symbol; array_clear(©->stack); array_push_all(©->stack, &cursor->stack); } diff --git a/tree_cursor.h b/tree_cursor.h index 6d4c688b..96a386df 100644 --- a/tree_cursor.h +++ b/tree_cursor.h @@ -14,6 +14,7 @@ typedef struct { typedef struct { const TSTree *tree; Array(TreeCursorEntry) stack; + TSSymbol root_alias_symbol; } TreeCursor; typedef enum { diff --git a/wasm.c b/wasm_store.c similarity index 69% rename from wasm.c rename to wasm_store.c index 02de0c70..7860c922 100644 --- a/wasm.c +++ b/wasm_store.c @@ -1,5 +1,5 @@ #include "api.h" -#include "parser.h" +#include "./parser.h" #include #ifdef TREE_SITTER_FEATURE_WASM @@ -10,54 +10,19 @@ #include "./alloc.h" #include "./array.h" #include "./atomic.h" +#include "./language.h" #include "./lexer.h" -#include "./wasm.h" -#include "./lexer.h" +#include "./wasm_store.h" #include "./wasm/wasm-stdlib.h" +#define array_len(a) (sizeof(a) / sizeof(a[0])) + // The following symbols from the C and C++ standard libraries are available // for external scanners to use. -#define STDLIB_SYMBOL_COUNT 34 -const char *STDLIB_SYMBOLS[STDLIB_SYMBOL_COUNT] = { - "_ZNKSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE4copyEPcmm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE17__assign_externalEPKcm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE17__assign_no_aliasILb0EEERS5_PKcm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE17__assign_no_aliasILb1EEERS5_PKcm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE25__init_copy_ctor_externalEPKcm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE6__initEPKcm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE7reserveEm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE9__grow_byEmmmmmm", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE9push_backEc", - "_ZNSt3__212basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEED2Ev", - "_ZNSt3__212basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEE9push_backEw", - "_ZNSt3__212basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEED2Ev", - "_ZdlPv", - "_Znwm", - "calloc", - "free", - "iswalnum", - "iswalpha", - "iswdigit", - "iswlower", - "iswspace", - "iswupper", - "malloc", - "memchr", - "memcmp", - "memcpy", - "memmove", - "memset", - "realloc", - "strcmp", - "strlen", - "strncpy", - "towlower", - "towupper" +const char *STDLIB_SYMBOLS[] = { + #include "./stdlib-symbols.txt" }; -#define BUILTIN_SYMBOL_COUNT 10 -#define MAX_IMPORT_COUNT (BUILTIN_SYMBOL_COUNT + STDLIB_SYMBOL_COUNT) - // The contents of the `dylink.0` custom section of a wasm module, // as specified by the current WebAssembly dynamic linking ABI proposal. typedef struct { @@ -67,13 +32,22 @@ typedef struct { uint32_t table_align; } WasmDylinkInfo; +// WasmLanguageId - A pointer used to identify a language. This language id is +// reference-counted, so that its ownership can be shared between the language +// itself and the instances of the language that are held in wasm stores. +typedef struct { + volatile uint32_t ref_count; + volatile uint32_t is_language_deleted; +} WasmLanguageId; + // LanguageWasmModule - Additional data associated with a wasm-backed // `TSLanguage`. This data is read-only and does not reference a particular // wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to // this is stored on the language itself. typedef struct { + volatile uint32_t ref_count; + WasmLanguageId *language_id; wasmtime_module_t *module; - uint32_t language_id; const char *name; char *symbol_name_buffer; char *field_name_buffer; @@ -84,7 +58,7 @@ typedef struct { // a `TSLanguage` in a particular wasm store. The wasm store holds one of // these structs for each language that it has instantiated. typedef struct { - uint32_t language_id; + WasmLanguageId *language_id; wasmtime_instance_t instance; int32_t external_states_address; int32_t lex_main_fn_index; @@ -96,6 +70,18 @@ typedef struct { int32_t scanner_scan_fn_index; } LanguageWasmInstance; +typedef struct { + uint32_t reset_heap; + uint32_t proc_exit; + uint32_t abort; + uint32_t assert_fail; + uint32_t notify_memory_growth; + uint32_t debug_message; + uint32_t at_exit; + uint32_t args_get; + uint32_t args_sizes_get; +} BuiltinFunctionIndices; + // TSWasmStore - A struct that allows a given `Parser` to use wasm-backed // languages. This struct is mutable, and can only be used by one parser at a // time. @@ -108,11 +94,14 @@ struct TSWasmStore { LanguageWasmInstance *current_instance; Array(LanguageWasmInstance) language_instances; uint32_t current_memory_offset; - uint32_t current_memory_size; uint32_t current_function_table_offset; - uint16_t fn_indices[STDLIB_SYMBOL_COUNT]; + uint32_t *stdlib_fn_indices; + BuiltinFunctionIndices builtin_fn_indices; + wasmtime_global_t stack_pointer_global; wasm_globaltype_t *const_i32_type; - wasm_globaltype_t *var_i32_type; + bool has_error; + uint32_t lexer_address; + uint32_t serialization_buffer_address; }; typedef Array(char) StringData; @@ -173,29 +162,8 @@ typedef struct { static volatile uint32_t NEXT_LANGUAGE_ID; // Linear memory layout: -// [ <-- stack | built-in data | heap --> | static data ] -#define STACK_SIZE (64 * 1024) -#define HEAP_SIZE (1024 * 1024) -#define INITIAL_MEMORY_SIZE (4 * 1024 * 1024 / MEMORY_PAGE_SIZE) -#define MAX_MEMORY_SIZE 32768 -#define SERIALIZATION_BUFFER_ADDRESS (STACK_SIZE) -#define LEXER_ADDRESS (SERIALIZATION_BUFFER_ADDRESS + TREE_SITTER_SERIALIZATION_BUFFER_SIZE) -#define HEAP_START_ADDRESS (LEXER_ADDRESS + sizeof(LexerInWasmMemory)) -#define DATA_START_ADDRESS (HEAP_START_ADDRESS + HEAP_SIZE) - -enum FunctionIx { - NULL_IX = 0, - PROC_EXIT_IX, - ABORT_IX, - ASSERT_FAIL_IX, - NOTIFY_MEMORY_GROWTH_IX, - AT_EXIT_IX, - LEXER_ADVANCE_IX, - LEXER_MARK_END_IX, - LEXER_GET_COLUMN_IX, - LEXER_IS_AT_INCLUDED_RANGE_START_IX, - LEXER_EOF_IX, -}; +// [ <-- stack | stdlib statics | lexer | serialization_buffer | language statics --> | heap --> ] +#define MAX_MEMORY_SIZE (128 * 1024 * 1024 / MEMORY_PAGE_SIZE) /************************ * WasmDylinkMemoryInfo @@ -273,27 +241,32 @@ static bool wasm_dylink_info__parse( * Native callbacks exposed to wasm modules *******************************************/ - static wasm_trap_t *callback__exit( + static wasm_trap_t *callback__abort( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, size_t args_and_results_len ) { - fprintf(stderr, "wasm module called exit"); - abort(); + return wasmtime_trap_new("wasm module called abort", 24); } -static wasm_trap_t *callback__notify_memory_growth( +static wasm_trap_t *callback__debug_message( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, size_t args_and_results_len ) { - fprintf(stderr, "wasm module called exit"); - abort(); + wasmtime_context_t *context = wasmtime_caller_context(caller); + TSWasmStore *store = env; + assert(args_and_results_len == 2); + uint32_t string_address = args_and_results[0].i32; + uint32_t value = args_and_results[1].i32; + uint8_t *memory = wasmtime_memory_data(context, &store->memory); + printf("DEBUG: %s %u\n", &memory[string_address], value); + return NULL; } -static wasm_trap_t *callback__at_exit( +static wasm_trap_t *callback__noop( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, @@ -317,7 +290,7 @@ static wasm_trap_t *callback__lexer_advance( lexer->advance(lexer, skip); uint8_t *memory = wasmtime_memory_data(context, &store->memory); - memcpy(&memory[LEXER_ADDRESS], &lexer->lookahead, sizeof(lexer->lookahead)); + memcpy(&memory[store->lexer_address], &lexer->lookahead, sizeof(lexer->lookahead)); return NULL; } @@ -373,12 +346,11 @@ static wasm_trap_t *callback__lexer_eof( } typedef struct { + uint32_t *storage_location; wasmtime_func_unchecked_callback_t callback; wasm_functype_t *type; } FunctionDefinition; -#define array_len(a) (sizeof(a) / sizeof(a[0])) - static void *copy(const void *data, size_t size) { void *result = ts_malloc(size); memcpy(result, data, size); @@ -453,17 +425,6 @@ static inline wasm_functype_t* wasm_functype_new_4_0( return wasm_functype_new(¶ms, &results); } -static wasmtime_extern_t get_builtin_func_extern( - wasmtime_context_t *context, - wasmtime_table_t *table, - unsigned index -) { - wasmtime_val_t val; - bool exists = wasmtime_table_get(context, table, index, &val); - assert(exists); - return (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = val.of.funcref}; -} - #define format(output, ...) \ do { \ size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \ @@ -471,6 +432,37 @@ static wasmtime_extern_t get_builtin_func_extern( snprintf(*output, message_length + 1, __VA_ARGS__); \ } while (0) +WasmLanguageId *language_id_new() { + WasmLanguageId *self = ts_malloc(sizeof(WasmLanguageId)); + self->is_language_deleted = false; + self->ref_count = 1; + return self; +} + +WasmLanguageId *language_id_clone(WasmLanguageId *self) { + atomic_inc(&self->ref_count); + return self; +} + +void language_id_delete(WasmLanguageId *self) { + if (atomic_dec(&self->ref_count) == 0) { + ts_free(self); + } +} + +static wasmtime_extern_t get_builtin_extern( + wasmtime_table_t *table, + unsigned index +) { + return (wasmtime_extern_t) { + .kind = WASMTIME_EXTERN_FUNC, + .of.func = (wasmtime_func_t) { + .store_id = table->store_id, + .index = index + } + }; +} + static bool ts_wasm_store__provide_builtin_import( TSWasmStore *self, const wasm_name_t *import_name, @@ -492,18 +484,8 @@ static bool ts_wasm_store__provide_builtin_import( error = wasmtime_global_new(context, self->const_i32_type, &value, &global); assert(!error); *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global}; - } else if (name_eq(import_name, "__heap_base")) { - wasmtime_val_t value = WASM_I32_VAL(HEAP_START_ADDRESS); - wasmtime_global_t global; - error = wasmtime_global_new(context, self->var_i32_type, &value, &global); - assert(!error); - *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global}; } else if (name_eq(import_name, "__stack_pointer")) { - wasmtime_val_t value = WASM_I32_VAL(STACK_SIZE); - wasmtime_global_t global; - error = wasmtime_global_new(context, self->var_i32_type, &value, &global); - assert(!error); - *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global}; + *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = self->stack_pointer_global}; } else if (name_eq(import_name, "__indirect_function_table")) { *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table}; } else if (name_eq(import_name, "memory")) { @@ -512,15 +494,21 @@ static bool ts_wasm_store__provide_builtin_import( // Builtin functions else if (name_eq(import_name, "__assert_fail")) { - *import = get_builtin_func_extern(context, &self->function_table, ASSERT_FAIL_IX); + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.assert_fail); } else if (name_eq(import_name, "__cxa_atexit")) { - *import = get_builtin_func_extern(context, &self->function_table, AT_EXIT_IX); + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.at_exit); + } else if (name_eq(import_name, "args_get")) { + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_get); + } else if (name_eq(import_name, "args_sizes_get")) { + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_sizes_get); } else if (name_eq(import_name, "abort")) { - *import = get_builtin_func_extern(context, &self->function_table, ABORT_IX); + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.abort); } else if (name_eq(import_name, "proc_exit")) { - *import = get_builtin_func_extern(context, &self->function_table, PROC_EXIT_IX); + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.proc_exit); } else if (name_eq(import_name, "emscripten_notify_memory_growth")) { - *import = get_builtin_func_extern(context, &self->function_table, NOTIFY_MEMORY_GROWTH_IX); + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.notify_memory_growth); + } else if (name_eq(import_name, "tree_sitter_debug_message")) { + *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.debug_message); } else { return false; } @@ -536,7 +524,8 @@ static bool ts_wasm_store__call_module_initializer( ) { if ( name_eq(export_name, "_initialize") || - name_eq(export_name, "__wasm_apply_data_relocs") + name_eq(export_name, "__wasm_apply_data_relocs") || + name_eq(export_name, "__wasm_call_ctors") ) { wasmtime_context_t *context = wasmtime_store_context(self->store); wasmtime_func_t initialization_func = export->of.func; @@ -549,141 +538,216 @@ static bool ts_wasm_store__call_module_initializer( } TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { - TSWasmStore *self = ts_malloc(sizeof(TSWasmStore)); + TSWasmStore *self = ts_calloc(1, sizeof(TSWasmStore)); wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL); wasmtime_context_t *context = wasmtime_store_context(store); wasmtime_error_t *error = NULL; wasm_trap_t *trap = NULL; wasm_message_t message = WASM_EMPTY_VEC; wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC; + wasmtime_extern_t *imports = NULL; + wasmtime_module_t *stdlib_module = NULL; + wasm_memorytype_t *memory_type = NULL; + wasm_tabletype_t *table_type = NULL; - // Initialize store's memory - wasm_limits_t memory_limits = {.min = INITIAL_MEMORY_SIZE, .max = MAX_MEMORY_SIZE}; - wasm_memorytype_t *memory_type = wasm_memorytype_new(&memory_limits); - wasmtime_memory_t memory; - error = wasmtime_memory_new(context, memory_type, &memory); + // Define functions called by scanners via function pointers on the lexer. + LexerInWasmMemory lexer = { + .lookahead = 0, + .result_symbol = 0, + }; + FunctionDefinition lexer_definitions[] = { + { + (uint32_t *)&lexer.advance, + callback__lexer_advance, + wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + (uint32_t *)&lexer.mark_end, + callback__lexer_mark_end, + wasm_functype_new_1_0(wasm_valtype_new_i32()) + }, + { + (uint32_t *)&lexer.get_column, + callback__lexer_get_column, + wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + (uint32_t *)&lexer.is_at_included_range_start, + callback__lexer_is_at_included_range_start, + wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + (uint32_t *)&lexer.eof, + callback__lexer_eof, + wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + }; + + // Define builtin functions that can be imported by scanners. + BuiltinFunctionIndices builtin_fn_indices; + FunctionDefinition builtin_definitions[] = { + { + &builtin_fn_indices.proc_exit, + callback__abort, + wasm_functype_new_1_0(wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.abort, + callback__abort, + wasm_functype_new_0_0() + }, + { + &builtin_fn_indices.assert_fail, + callback__abort, + wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.notify_memory_growth, + callback__noop, + wasm_functype_new_1_0(wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.debug_message, + callback__debug_message, + wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.at_exit, + callback__noop, + wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.args_get, + callback__noop, + wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + { + &builtin_fn_indices.args_sizes_get, + callback__noop, + wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) + }, + }; + + // Create all of the wasm functions. + unsigned builtin_definitions_len = array_len(builtin_definitions); + unsigned lexer_definitions_len = array_len(lexer_definitions); + for (unsigned i = 0; i < builtin_definitions_len; i++) { + FunctionDefinition *definition = &builtin_definitions[i]; + wasmtime_func_t func; + wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func); + *definition->storage_location = func.index; + wasm_functype_delete(definition->type); + } + for (unsigned i = 0; i < lexer_definitions_len; i++) { + FunctionDefinition *definition = &lexer_definitions[i]; + wasmtime_func_t func; + wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func); + *definition->storage_location = func.index; + wasm_functype_delete(definition->type); + } + + // Compile the stdlib module. + error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module); if (error) { wasmtime_error_message(error, &message); - wasm_error->kind = TSWasmErrorKindAllocate; + wasm_error->kind = TSWasmErrorKindCompile; format( &wasm_error->message, - "failed to allocate wasm memory: %.*s", + "failed to compile wasm stdlib: %.*s", (int)message.size, message.data ); goto error; } - wasm_memorytype_delete(memory_type); - // Initialize lexer struct with function pointers in wasm memory. - uint8_t *memory_data = wasmtime_memory_data(context, &memory); - LexerInWasmMemory lexer = { - .lookahead = 0, - .result_symbol = 0, - .advance = LEXER_ADVANCE_IX, - .mark_end = LEXER_MARK_END_IX, - .get_column = LEXER_GET_COLUMN_IX, - .is_at_included_range_start = LEXER_IS_AT_INCLUDED_RANGE_START_IX, - .eof = LEXER_EOF_IX, - }; - memcpy(&memory_data[LEXER_ADDRESS], &lexer, sizeof(lexer)); - - // Define builtin functions. - FunctionDefinition definitions[] = { - [NULL_IX] = {NULL, NULL}, - [PROC_EXIT_IX] = {callback__exit, wasm_functype_new_1_0(wasm_valtype_new_i32())}, - [ABORT_IX] = {callback__exit, wasm_functype_new_0_0()}, - [ASSERT_FAIL_IX] = {callback__exit, wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [NOTIFY_MEMORY_GROWTH_IX] = {callback__notify_memory_growth, wasm_functype_new_1_0(wasm_valtype_new_i32())}, - [AT_EXIT_IX] = {callback__at_exit, wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [LEXER_ADVANCE_IX] = {callback__lexer_advance, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [LEXER_MARK_END_IX] = {callback__lexer_mark_end, wasm_functype_new_1_0(wasm_valtype_new_i32())}, - [LEXER_GET_COLUMN_IX] = {callback__lexer_get_column, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [LEXER_IS_AT_INCLUDED_RANGE_START_IX] = {callback__lexer_is_at_included_range_start, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [LEXER_EOF_IX] = {callback__lexer_eof, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - }; - unsigned definitions_len = array_len(definitions); + // Retrieve the stdlib module's imports. + wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; + wasmtime_module_imports(stdlib_module, &import_types); - // Add builtin functions to the store's function table. - wasmtime_table_t function_table; - wasm_limits_t table_limits = {.min = definitions_len, .max = wasm_limits_max_default}; - wasm_tabletype_t *table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); - wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF}; - error = wasmtime_table_new(context, table_type, &initializer, &function_table); + // Find the initial number of memory pages needed by the stdlib. + const wasm_memorytype_t *stdlib_memory_type; + for (unsigned i = 0; i < import_types.size; i++) { + wasm_importtype_t *import_type = import_types.data[i]; + const wasm_name_t *import_name = wasm_importtype_name(import_type); + if (name_eq(import_name, "memory")) { + const wasm_externtype_t *type = wasm_importtype_type(import_type); + stdlib_memory_type = wasm_externtype_as_memorytype_const(type); + } + } + if (!stdlib_memory_type) { + wasm_error->kind = TSWasmErrorKindCompile; + format( + &wasm_error->message, + "wasm stdlib is missing the 'memory' import" + ); + goto error; + } + + // Initialize store's memory + uint64_t initial_memory_pages = wasmtime_memorytype_minimum(stdlib_memory_type); + wasm_limits_t memory_limits = {.min = initial_memory_pages, .max = MAX_MEMORY_SIZE}; + memory_type = wasm_memorytype_new(&memory_limits); + wasmtime_memory_t memory; + error = wasmtime_memory_new(context, memory_type, &memory); if (error) { wasmtime_error_message(error, &message); wasm_error->kind = TSWasmErrorKindAllocate; format( &wasm_error->message, - "failed to allocate wasm table: %.*s", + "failed to allocate wasm memory: %.*s", (int)message.size, message.data ); goto error; } - wasm_tabletype_delete(table_type); + wasm_memorytype_delete(memory_type); + memory_type = NULL; - uint32_t prev_size; - error = wasmtime_table_grow(context, &function_table, definitions_len, &initializer, &prev_size); + // Initialize store's function table + wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default}; + table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); + wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF}; + wasmtime_table_t function_table; + error = wasmtime_table_new(context, table_type, &initializer, &function_table); if (error) { wasmtime_error_message(error, &message); wasm_error->kind = TSWasmErrorKindAllocate; format( &wasm_error->message, - "failed to grow wasm table to initial size: %.*s", + "failed to allocate wasm table: %.*s", (int)message.size, message.data ); goto error; } + wasm_tabletype_delete(table_type); + table_type = NULL; - for (unsigned i = 1; i < definitions_len; i++) { - FunctionDefinition *definition = &definitions[i]; - wasmtime_func_t func; - wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func); - wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func}; - error = wasmtime_table_set(context, &function_table, i, &func_val); - assert(!error); - wasm_functype_delete(definition->type); - } + unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS); + + // Define globals for the stack and heap start addresses. + wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST); + wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR); + + wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0); + wasmtime_global_t stack_pointer_global; + error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global); + assert(!error); *self = (TSWasmStore) { - .store = store, .engine = engine, + .store = store, .memory = memory, - .language_instances = array_new(), .function_table = function_table, + .language_instances = array_new(), + .stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)), + .builtin_fn_indices = builtin_fn_indices, + .stack_pointer_global = stack_pointer_global, .current_memory_offset = 0, - .current_memory_size = 64 * MEMORY_PAGE_SIZE, - .current_function_table_offset = definitions_len, - .const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST), - .var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR), + .current_function_table_offset = 0, + .const_i32_type = const_i32_type, }; - WasmDylinkInfo dylink_info; - if (!wasm_dylink_info__parse(STDLIB_WASM, STDLIB_WASM_LEN, &dylink_info)) { - wasm_error->kind = TSWasmErrorKindParse; - format(&wasm_error->message, "failed to parse wasm stdlib"); - goto error; - } - - wasmtime_module_t *stdlib_module; - error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module); - if (error) { - wasmtime_error_message(error, &message); - wasm_error->kind = TSWasmErrorKindCompile; - format( - &wasm_error->message, - "failed to compile wasm stdlib: %.*s", - (int)message.size, message.data - ); - goto error; - } - - wasmtime_instance_t instance; - wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; - wasmtime_module_imports(stdlib_module, &import_types); - if (import_types.size > MAX_IMPORT_COUNT) goto error; - - wasmtime_extern_t imports[MAX_IMPORT_COUNT]; - for (unsigned i = 0; i < import_types.size && i < MAX_IMPORT_COUNT; i++) { + // Set up the imports for the stdlib module. + imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t)); + for (unsigned i = 0; i < import_types.size; i++) { wasm_importtype_t *type = import_types.data[i]; const wasm_name_t *import_name = wasm_importtype_name(type); if (!ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) { @@ -697,7 +761,11 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } } + // Instantiate the stdlib module. + wasmtime_instance_t instance; error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap); + ts_free(imports); + imports = NULL; if (error) { wasmtime_error_message(error, &message); wasm_error->kind = TSWasmErrorKindInstantiate; @@ -720,14 +788,10 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } wasm_importtype_vec_delete(&import_types); - self->current_memory_offset = DATA_START_ADDRESS + dylink_info.memory_size; - self->current_function_table_offset += dylink_info.table_size; - - for (unsigned i = 0; i < STDLIB_SYMBOL_COUNT; i++) { - self->fn_indices[i] = UINT16_MAX; - } - // Process the stdlib module's exports. + for (unsigned i = 0; i < stdlib_symbols_len; i++) { + self->stdlib_fn_indices[i] = UINT32_MAX; + } wasmtime_module_exports(stdlib_module, &export_types); for (unsigned i = 0; i < export_types.size; i++) { wasm_exporttype_t *export_type = export_types.data[i]; @@ -739,6 +803,12 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export); assert(exists); + if (export.kind == WASMTIME_EXTERN_GLOBAL) { + if (name_eq(name, "__stack_pointer")) { + self->stack_pointer_global = export.of.global; + } + } + if (export.kind == WASMTIME_EXTERN_FUNC) { if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) { if (trap) { @@ -754,17 +824,31 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { continue; } - for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) { + if (name_eq(name, "reset_heap")) { + self->builtin_fn_indices.reset_heap = export.of.func.index; + continue; + } + + for (unsigned j = 0; j < stdlib_symbols_len; j++) { if (name_eq(name, STDLIB_SYMBOLS[j])) { - self->fn_indices[j] = export.of.func.index; + self->stdlib_fn_indices[j] = export.of.func.index; break; } } } } - for (unsigned i = 0; i < STDLIB_SYMBOL_COUNT; i++) { - if (self->fn_indices[i] == UINT16_MAX) { + if (self->builtin_fn_indices.reset_heap == UINT32_MAX) { + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "missing malloc reset function in wasm stdlib" + ); + goto error; + } + + for (unsigned i = 0; i < stdlib_symbols_len; i++) { + if (self->stdlib_fn_indices[i] == UINT32_MAX) { wasm_error->kind = TSWasmErrorKindInstantiate; format( &wasm_error->message, @@ -776,28 +860,86 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } wasm_exporttype_vec_delete(&export_types); + wasmtime_module_delete(stdlib_module); + + // Add all of the lexer callback functions to the function table. Store their function table + // indices on the in-memory lexer. + uint32_t table_index; + error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &table_index); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to grow wasm table to initial size: %.*s", + (int)message.size, message.data + ); + goto error; + } + for (unsigned i = 0; i < lexer_definitions_len; i++) { + FunctionDefinition *definition = &lexer_definitions[i]; + wasmtime_func_t func = {function_table.store_id, *definition->storage_location}; + wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func}; + error = wasmtime_table_set(context, &function_table, table_index, &func_val); + assert(!error); + *(int32_t *)(definition->storage_location) = table_index; + table_index++; + } + + self->current_function_table_offset = table_index; + self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE; + self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory); + self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE; + + // Grow the memory enough to hold the builtin lexer and serialization buffer. + uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1; + uint64_t prev_memory_size; + wasmtime_memory_grow(context, &memory, new_pages_needed, &prev_memory_size); + + uint8_t *memory_data = wasmtime_memory_data(context, &memory); + memcpy(&memory_data[self->lexer_address], &lexer, sizeof(lexer)); return self; error: ts_free(self); + if (stdlib_module) wasmtime_module_delete(stdlib_module); if (store) wasmtime_store_delete(store); + if (import_types.size) wasm_importtype_vec_delete(&import_types); + if (memory_type) wasm_memorytype_delete(memory_type); + if (table_type) wasm_tabletype_delete(table_type); if (trap) wasm_trap_delete(trap); if (error) wasmtime_error_delete(error); if (message.size) wasm_byte_vec_delete(&message); if (export_types.size) wasm_exporttype_vec_delete(&export_types); + if (imports) ts_free(imports); return NULL; } void ts_wasm_store_delete(TSWasmStore *self) { if (!self) return; + ts_free(self->stdlib_fn_indices); wasm_globaltype_delete(self->const_i32_type); - wasm_globaltype_delete(self->var_i32_type); wasmtime_store_delete(self->store); wasm_engine_delete(self->engine); + for (unsigned i = 0; i < self->language_instances.size; i++) { + LanguageWasmInstance *instance = &self->language_instances.contents[i]; + language_id_delete(instance->language_id); + } array_delete(&self->language_instances); ts_free(self); } +size_t ts_wasm_store_language_count(const TSWasmStore *self) { + size_t result = 0; + for (unsigned i = 0; i < self->language_instances.size; i++) { + const WasmLanguageId *id = self->language_instances.contents[i].language_id; + if (!id->is_language_deleted) { + result++; + } + } + return result; +} + static bool ts_wasm_store__instantiate( TSWasmStore *self, wasmtime_module_t *module, @@ -811,6 +953,7 @@ static bool ts_wasm_store__instantiate( wasm_trap_t *trap = NULL; wasm_message_t message = WASM_EMPTY_VEC; char *language_function_name = NULL; + wasmtime_extern_t *imports = NULL; wasmtime_context_t *context = wasmtime_store_context(self->store); // Grow the function table to make room for the new functions. @@ -824,9 +967,10 @@ static bool ts_wasm_store__instantiate( // Grow the memory to make room for the new data. uint32_t needed_memory_size = self->current_memory_offset + dylink_info->memory_size; - if (needed_memory_size > self->current_memory_size) { + uint32_t current_memory_size = wasmtime_memory_data_size(context, &self->memory); + if (needed_memory_size > current_memory_size) { uint32_t pages_to_grow = ( - needed_memory_size - self->current_memory_size + MEMORY_PAGE_SIZE - 1) / + needed_memory_size - current_memory_size + MEMORY_PAGE_SIZE - 1) / MEMORY_PAGE_SIZE; uint64_t prev_memory_size; error = wasmtime_memory_grow(context, &self->memory, pages_to_grow, &prev_memory_size); @@ -834,7 +978,6 @@ static bool ts_wasm_store__instantiate( format(error_message, "invalid memory size %u", dylink_info->memory_size); goto error; } - self->current_memory_size += pages_to_grow * MEMORY_PAGE_SIZE; } // Construct the language function name as string. @@ -845,8 +988,7 @@ static bool ts_wasm_store__instantiate( // Build the imports list for the module. wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; wasmtime_module_imports(module, &import_types); - if (import_types.size > MAX_IMPORT_COUNT) goto error; - wasmtime_extern_t imports[MAX_IMPORT_COUNT]; + imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t)); for (unsigned i = 0; i < import_types.size; i++) { const wasm_importtype_t *import_type = import_types.data[i]; @@ -863,7 +1005,7 @@ static bool ts_wasm_store__instantiate( bool defined_in_stdlib = false; for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) { if (name_eq(import_name, STDLIB_SYMBOLS[j])) { - uint16_t address = self->fn_indices[j]; + uint16_t address = self->stdlib_fn_indices[j]; imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = {store_id, address}}; defined_in_stdlib = true; break; @@ -883,6 +1025,8 @@ static bool ts_wasm_store__instantiate( wasmtime_instance_t instance; error = wasmtime_instance_new(context, module, imports, import_types.size, &instance, &trap); wasm_importtype_vec_delete(&import_types); + ts_free(imports); + imports = NULL; if (error) { wasmtime_error_message(error, &message); format( @@ -984,6 +1128,7 @@ static bool ts_wasm_store__instantiate( if (message.size) wasm_byte_vec_delete(&message); if (error) wasmtime_error_delete(error); if (trap) wasm_trap_delete(trap); + if (imports) ts_free(imports); return false; } @@ -1043,6 +1188,12 @@ const TSLanguage *ts_wasm_store_load_language( const uint8_t *memory = wasmtime_memory_data(context, &self->memory); memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory)); + if (wasm_language.version < LANGUAGE_VERSION_USABLE_VIA_WASM) { + wasm_error->kind = TSWasmErrorKindInstantiate; + format(&wasm_error->message, "language version %u is too old for wasm", wasm_language.version); + goto error; + } + int32_t addresses[] = { wasm_language.alias_map, wasm_language.alias_sequences, @@ -1074,7 +1225,7 @@ const TSLanguage *ts_wasm_store_load_language( }; uint32_t address_count = array_len(addresses); - TSLanguage *language = ts_malloc(sizeof(TSLanguage)); + TSLanguage *language = ts_calloc(1, sizeof(TSLanguage)); StringData symbol_name_buffer = array_new(); StringData field_name_buffer = array_new(); @@ -1174,7 +1325,7 @@ const TSLanguage *ts_wasm_store_load_language( ); } - if (language->version >= 14) { + if (language->version >= LANGUAGE_VERSION_WITH_PRIMARY_STATES) { language->primary_state_ids = copy( &memory[wasm_language.primary_state_ids], wasm_language.state_count * sizeof(TSStateId) @@ -1196,12 +1347,13 @@ const TSLanguage *ts_wasm_store_load_language( LanguageWasmModule *language_module = ts_malloc(sizeof(LanguageWasmModule)); *language_module = (LanguageWasmModule) { - .language_id = atomic_inc(&NEXT_LANGUAGE_ID), + .language_id = language_id_new(), .module = module, .name = name, .symbol_name_buffer = symbol_name_buffer.contents, .field_name_buffer = field_name_buffer.contents, .dylink_info = dylink_info, + .ref_count = 1, }; // The lex functions are not used for wasm languages. Use those two fields @@ -1210,10 +1362,19 @@ const TSLanguage *ts_wasm_store_load_language( language->lex_fn = ts_wasm_store__sentinel_lex_fn; language->keyword_lex_fn = (void *)language_module; - // Store some information about this store's specific instance of this - // language module, keyed by the language's id. + // Clear out any instances of languages that have been deleted. + for (unsigned i = 0; i < self->language_instances.size; i++) { + WasmLanguageId *id = self->language_instances.contents[i].language_id; + if (id->is_language_deleted) { + language_id_delete(id); + array_erase(&self->language_instances, i); + i--; + } + } + + // Store this store's instance of this language module. array_push(&self->language_instances, ((LanguageWasmInstance) { - .language_id = language_module->language_id, + .language_id = language_id_clone(language_module->language_id), .instance = instance, .external_states_address = wasm_language.external_scanner.states, .lex_main_fn_index = wasm_language.lex_fn, @@ -1240,19 +1401,25 @@ bool ts_wasm_store_add_language( wasmtime_context_t *context = wasmtime_store_context(self->store); const LanguageWasmModule *language_module = (void *)language->keyword_lex_fn; - // Search for the information about this store's instance of the language module. + // Search for this store's instance of the language module. Also clear out any + // instances of languages that have been deleted. bool exists = false; - array_search_sorted_by( - &self->language_instances, - .language_id, - language_module->language_id, - index, - &exists - ); + for (unsigned i = 0; i < self->language_instances.size; i++) { + WasmLanguageId *id = self->language_instances.contents[i].language_id; + if (id->is_language_deleted) { + language_id_delete(id); + array_erase(&self->language_instances, i); + i--; + } else if (id == language_module->language_id) { + exists = true; + *index = i; + } + } // If the language module has not been instantiated in this store, then add // it to this store. if (!exists) { + *index = self->language_instances.size; char *message; wasmtime_instance_t instance; int32_t language_address; @@ -1272,8 +1439,8 @@ bool ts_wasm_store_add_language( LanguageInWasmMemory wasm_language; const uint8_t *memory = wasmtime_memory_data(context, &self->memory); memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory)); - array_insert(&self->language_instances, *index, ((LanguageWasmInstance) { - .language_id = language_module->language_id, + array_push(&self->language_instances, ((LanguageWasmInstance) { + .language_id = language_id_clone(language_module->language_id), .instance = instance, .external_states_address = wasm_language.external_scanner.states, .lex_main_fn_index = wasm_language.lex_fn, @@ -1289,17 +1456,37 @@ bool ts_wasm_store_add_language( return true; } +void ts_wasm_store_reset_heap(TSWasmStore *self) { + wasmtime_context_t *context = wasmtime_store_context(self->store); + wasmtime_func_t func = { + self->function_table.store_id, + self->builtin_fn_indices.reset_heap + }; + wasm_trap_t *trap = NULL; + wasmtime_val_t args[1] = { + {.of.i32 = self->current_memory_offset, .kind = WASMTIME_I32}, + }; + + wasmtime_error_t *error = wasmtime_func_call(context, &func, args, 1, NULL, 0, &trap); + assert(!error); + assert(!trap); +} + bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) { uint32_t instance_index; if (!ts_wasm_store_add_language(self, language, &instance_index)) return false; self->current_lexer = lexer; self->current_instance = &self->language_instances.contents[instance_index]; + self->has_error = false; + ts_wasm_store_reset_heap(self); return true; } -void ts_wasm_store_stop(TSWasmStore *self) { +void ts_wasm_store_reset(TSWasmStore *self) { self->current_lexer = NULL; self->current_instance = NULL; + self->has_error = false; + ts_wasm_store_reset_heap(self); } static void ts_wasm_store__call( @@ -1317,17 +1504,26 @@ static void ts_wasm_store__call( wasm_trap_t *trap = NULL; wasmtime_error_t *error = wasmtime_func_call_unchecked(context, &func, args_and_results, args_and_results_len, &trap); - assert(!error); - if (trap) { - wasm_message_t message; - wasm_trap_message(trap, &message); - fprintf( - stderr, - "trap when calling wasm lexing function %u: %.*s\n", - function_index, - (int)message.size, message.data - ); - abort(); + if (error) { + // wasm_message_t message; + // wasmtime_error_message(error, &message); + // fprintf( + // stderr, + // "error in wasm module: %.*s\n", + // (int)message.size, message.data + // ); + wasmtime_error_delete(error); + self->has_error = true; + } else if (trap) { + // wasm_message_t message; + // wasm_trap_message(trap, &message); + // fprintf( + // stderr, + // "trap in wasm module: %.*s\n", + // (int)message.size, message.data + // ); + wasm_trap_delete(trap); + self->has_error = true; } } @@ -1335,21 +1531,22 @@ static bool ts_wasm_store__call_lex_function(TSWasmStore *self, unsigned functio wasmtime_context_t *context = wasmtime_store_context(self->store); uint8_t *memory_data = wasmtime_memory_data(context, &self->memory); memcpy( - &memory_data[LEXER_ADDRESS], + &memory_data[self->lexer_address], &self->current_lexer->lookahead, sizeof(self->current_lexer->lookahead) ); wasmtime_val_raw_t args[2] = { - {.i32 = LEXER_ADDRESS}, + {.i32 = self->lexer_address}, {.i32 = state}, }; ts_wasm_store__call(self, function_index, args, 2); + if (self->has_error) return false; bool result = args[0].i32; memcpy( &self->current_lexer->lookahead, - &memory_data[LEXER_ADDRESS], + &memory_data[self->lexer_address], sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol) ); return result; @@ -1374,12 +1571,15 @@ bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) { uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) { wasmtime_val_raw_t args[1] = {{.i32 = 0}}; ts_wasm_store__call(self, self->current_instance->scanner_create_fn_index, args, 1); + if (self->has_error) return 0; return args[0].i32; } void ts_wasm_store_call_scanner_destroy(TSWasmStore *self, uint32_t scanner_address) { - wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}}; - ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1); + if (self->current_instance) { + wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}}; + ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1); + } } bool ts_wasm_store_call_scanner_scan( @@ -1391,7 +1591,7 @@ bool ts_wasm_store_call_scanner_scan( uint8_t *memory_data = wasmtime_memory_data(context, &self->memory); memcpy( - &memory_data[LEXER_ADDRESS], + &memory_data[self->lexer_address], &self->current_lexer->lookahead, sizeof(self->current_lexer->lookahead) ); @@ -1401,14 +1601,15 @@ bool ts_wasm_store_call_scanner_scan( (valid_tokens_ix * sizeof(bool)); wasmtime_val_raw_t args[3] = { {.i32 = scanner_address}, - {.i32 = LEXER_ADDRESS}, + {.i32 = self->lexer_address}, {.i32 = valid_tokens_address} }; ts_wasm_store__call(self, self->current_instance->scanner_scan_fn_index, args, 3); + if (self->has_error) return false; memcpy( &self->current_lexer->lookahead, - &memory_data[LEXER_ADDRESS], + &memory_data[self->lexer_address], sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol) ); return args[0].i32; @@ -1424,15 +1625,17 @@ uint32_t ts_wasm_store_call_scanner_serialize( wasmtime_val_raw_t args[2] = { {.i32 = scanner_address}, - {.i32 = SERIALIZATION_BUFFER_ADDRESS}, + {.i32 = self->serialization_buffer_address}, }; ts_wasm_store__call(self, self->current_instance->scanner_serialize_fn_index, args, 2); + if (self->has_error) return 0; + uint32_t length = args[0].i32; if (length > 0) { memcpy( ((Lexer *)self->current_lexer)->debug_buffer, - &memory_data[SERIALIZATION_BUFFER_ADDRESS], + &memory_data[self->serialization_buffer_address], length ); } @@ -1450,7 +1653,7 @@ void ts_wasm_store_call_scanner_deserialize( if (length > 0) { memcpy( - &memory_data[SERIALIZATION_BUFFER_ADDRESS], + &memory_data[self->serialization_buffer_address], buffer, length ); @@ -1458,16 +1661,64 @@ void ts_wasm_store_call_scanner_deserialize( wasmtime_val_raw_t args[3] = { {.i32 = scanner_address}, - {.i32 = SERIALIZATION_BUFFER_ADDRESS}, + {.i32 = self->serialization_buffer_address}, {.i32 = length}, }; ts_wasm_store__call(self, self->current_instance->scanner_deserialize_fn_index, args, 3); } +bool ts_wasm_store_has_error(const TSWasmStore *self) { + return self->has_error; +} + bool ts_language_is_wasm(const TSLanguage *self) { return self->lex_fn == ts_wasm_store__sentinel_lex_fn; } +static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) { + return (LanguageWasmModule *)self->keyword_lex_fn; +} + +void ts_wasm_language_retain(const TSLanguage *self) { + LanguageWasmModule *module = ts_language__wasm_module(self); + assert(module->ref_count > 0); + atomic_inc(&module->ref_count); +} + +void ts_wasm_language_release(const TSLanguage *self) { + LanguageWasmModule *module = ts_language__wasm_module(self); + assert(module->ref_count > 0); + if (atomic_dec(&module->ref_count) == 0) { + // Update the language id to reflect that the language is deleted. This allows any wasm stores + // that hold wasm instances for this language to delete those instances. + atomic_inc(&module->language_id->is_language_deleted); + language_id_delete(module->language_id); + + ts_free((void *)module->field_name_buffer); + ts_free((void *)module->symbol_name_buffer); + ts_free((void *)module->name); + wasmtime_module_delete(module->module); + ts_free(module); + + ts_free((void *)self->alias_map); + ts_free((void *)self->alias_sequences); + ts_free((void *)self->external_scanner.symbol_map); + ts_free((void *)self->field_map_entries); + ts_free((void *)self->field_map_slices); + ts_free((void *)self->field_names); + ts_free((void *)self->lex_modes); + ts_free((void *)self->parse_actions); + ts_free((void *)self->parse_table); + ts_free((void *)self->primary_state_ids); + ts_free((void *)self->public_symbol_map); + ts_free((void *)self->small_parse_table); + ts_free((void *)self->small_parse_table_map); + ts_free((void *)self->symbol_metadata); + ts_free((void *)self->symbol_names); + ts_free((void *)self); + } +} + #else // If the WASM feature is not enabled, define dummy versions of all of the @@ -1488,7 +1739,7 @@ bool ts_wasm_store_start( return false; } -void ts_wasm_store_stop(TSWasmStore *self) { +void ts_wasm_store_reset(TSWasmStore *self) { (void)self; } @@ -1551,9 +1802,22 @@ void ts_wasm_store_call_scanner_deserialize( (void)length; } +bool ts_wasm_store_has_error(const TSWasmStore *self) { + (void)self; + return false; +} + bool ts_language_is_wasm(const TSLanguage *self) { (void)self; return false; } +void ts_wasm_language_retain(const TSLanguage *self) { + (void)self; +} + +void ts_wasm_language_release(const TSLanguage *self) { + (void)self; +} + #endif diff --git a/wasm.h b/wasm_store.h similarity index 77% rename from wasm.h rename to wasm_store.h index bca586c2..1ad2ae57 100644 --- a/wasm.h +++ b/wasm_store.h @@ -6,10 +6,11 @@ extern "C" { #endif #include "api.h" -#include "parser.h" +#include "./parser.h" bool ts_wasm_store_start(TSWasmStore *, TSLexer *, const TSLanguage *); -void ts_wasm_store_stop(TSWasmStore *); +void ts_wasm_store_reset(TSWasmStore *); +bool ts_wasm_store_has_error(const TSWasmStore *); bool ts_wasm_store_call_lex_main(TSWasmStore *, TSStateId); bool ts_wasm_store_call_lex_keyword(TSWasmStore *, TSStateId); @@ -20,6 +21,9 @@ bool ts_wasm_store_call_scanner_scan(TSWasmStore *, uint32_t, uint32_t); uint32_t ts_wasm_store_call_scanner_serialize(TSWasmStore *, uint32_t, char *); void ts_wasm_store_call_scanner_deserialize(TSWasmStore *, uint32_t, const char *, unsigned); +void ts_wasm_language_retain(const TSLanguage *); +void ts_wasm_language_release(const TSLanguage *); + #ifdef __cplusplus } #endif