MySQL.jl
Package: MySQL.jl
Implementor: quinnj
Date: 2018-09-05
GitHub Pull Request: https://github.com/JuliaDatabases/MySQL.jl/pull/117
Ok, so looking over the package, it looks like it has a high-level function `MySQL.query(conn, sql)` that takes a valid `MySQL.Connection` `conn` and a SQL string `sql`, executes the statement, and then uses the DataStreams package to stream results out to any sink.
There's an internal type `MySQL.Query` whose constructor actually holds the logic for preparing a statement, executing it against the database, and doing some initial diligence on the layout of the result set (column names, types), if any.
So to start off, we're going to define:

```julia
Tables.istable(::Type{<:Query}) = true
Tables.rowaccess(::Type{<:Query}) = true
Tables.rows(q::Query) = q
```
These affirm that a `MySQL.Query` object is a `Tables.Table` and that it will iterate rows. Note that I defined `istable` and `rowaccess` on the type `<:Query`; this is because the type definition is:
```julia
mutable struct Query{hasresult, names, T}
    result::Result
    ptr::Ptr{Ptr{Int8}}
    ncols::Int
    nrows::Int
end
```
This means that `MySQL.Query` on its own isn't a concrete type: it's a `UnionAll` over its three type parameters, and for dispatch it behaves much like an abstract type. `Query{true, (:col1, :col2), Tuple{Int, Float64}}` is an example of a concrete type that is a subtype of `Query` (see the manual section on parametric types for more info there). Put simply, I want to make sure that any concrete `Query` type is recognized as a table with row access, regardless of the combination of type parameters any one instance may have.
For `Tables.rows`, we just return the `MySQL.Query` object itself, since it will be simplest to define iteration on it directly.
Ok, next definition:
```julia
Tables.schema(q::Query{hasresult, names, T}) where {hasresult, names, T} = Tables.Schema(names, T)
```
Here I'm defining `Tables.schema` on the result of `Tables.rows`, which just so happens to be the original "table" anyway. Luckily for us, the `Query` object already encodes the column names and types as type parameters, so we can use those directly to call `Tables.Schema(names, T)` to construct the schema object. Done.
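As a rough illustration of reading the schema straight off the type parameters, the sketch below uses a hypothetical `TypedQuery` type and two made-up accessors, `toy_names` and `toy_types`. The real code hands the same two pieces to `Tables.Schema(names, T)` instead.

```julia
# Hypothetical stand-in for MySQL.Query; only its type parameters matter here.
mutable struct TypedQuery{hasresult, names, T}
    nrows::Int
end

# `where` clauses pull the column names and types back out of the type,
# which is all a schema constructor like `Tables.Schema(names, T)` needs.
toy_names(::TypedQuery{h, names, T}) where {h, names, T} = names
toy_types(::TypedQuery{h, names, T}) where {h, names, T} = T

q = TypedQuery{true, (:id, :score), Tuple{Int, Float64}}(3)
toy_names(q)              # (:id, :score)
fieldtypes(toy_types(q))  # (Int64, Float64) on a 64-bit machine
```

No field of the object is read here: everything comes from the type, which is why it's cheap to compute.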
Now all that's left is to actually ensure that the table (`MySQL.Query`) can iterate rows! Let's start off with:
```julia
Base.length(q::Query) = q.ptr == C_NULL ? 0 : q.nrows
Base.eltype(q::Query{hasresult, names, types}) where {hasresult, names, types} = NamedTuple{names, types}
```
Luckily for us, the `Query` type stores the number of rows in the result, so we can access that directly for our `length` definition. We also check `q.ptr == C_NULL` in case a user tries to re-iterate a `Query`: executing a `Query` produces a forward-only, read-once iterator, so reporting a length of 0 in that case guards against iterator misuse.
Similarly for `eltype`, we can again use the encoded `names` and `types` to create a `NamedTuple{names, types}` type object to signal what we'll be iterating.
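Here's a minimal sketch of the `length`/`eltype` pairing on a read-once iterator. `OnceQuery` is hypothetical: a plain `Vector` stands in for the server-side result set, and an `exhausted` flag plays the role of the `q.ptr == C_NULL` check. One deliberate difference from the code above: `eltype` is defined on the type rather than the instance, since `Base.eltype(x)` falls back to `eltype(typeof(x))`, so a method on the type covers both forms of the call.

```julia
# Hypothetical read-once query (not MySQL.jl code): `rows` stands in for the
# server-side result set and `exhausted` plays the role of `q.ptr == C_NULL`.
mutable struct OnceQuery{names, types}
    rows::Vector{Tuple}
    exhausted::Bool
end
OnceQuery{names, types}(rows) where {names, types} = OnceQuery{names, types}(rows, false)

# Once the rows have been consumed, report zero length so re-iteration is a no-op.
Base.length(q::OnceQuery) = q.exhausted ? 0 : length(q.rows)

# The row type comes straight from the type parameters.
Base.eltype(::Type{OnceQuery{names, types}}) where {names, types} = NamedTuple{names, types}

function Base.iterate(q::OnceQuery{names, types}, st = 1) where {names, types}
    if st > length(q)
        q.exhausted = true   # the cursor can't be rewound
        return nothing
    end
    return NamedTuple{names, types}(q.rows[st]), st + 1
end

q = OnceQuery{(:id, :name), Tuple{Int, String}}(Tuple[(1, "a"), (2, "b")])
first_pass = collect(q)   # two NamedTuple rows
length(q)                 # 0: the query has been consumed
isempty(collect(q))       # true
```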
Ok, let's get to the main course here: `iterate`. First, let's familiarize ourselves with the helper functions that are defined for getting an individual column's value in a row:
```julia
cast(str, ::Type{Union{Missing, T}}) where {T} = cast(str, T)
cast(str, ::Type{API.Bit}) = API.Bit(isempty(str) ? 0 : UInt64(str[1]))
cast(str, ::Type{T}) where {T<:Number} = parse(T, str)
cast(str, ::Type{Vector{UInt8}}) = Vector{UInt8}(str)
cast(str, ::Type{<:AbstractString}) = str
cast(str, ::Type{Time}) = mysql_time(str)
cast(str, ::Type{Date}) = mysql_date(str)
cast(str, ::Type{DateTime}) = mysql_datetime(str)

function getvalue(ptr, col, ::Type{T}) where {T}
    deref = unsafe_load(ptr, col)
    return deref == C_NULL ? missing : cast(unsafe_string(deref), T)
end
```
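The `cast`/`getvalue` dispatch chain can be tried out without a live connection. In this sketch, which is an assumption rather than MySQL.jl code, a row is a `Vector{Union{Nothing, String}}` instead of a C `char**`, `nothing` stands in for a `NULL` pointer, and `Date` parsing uses the Dates standard library in place of MySQL.jl's `mysql_date` helper. The `toy_` names are hypothetical.

```julia
using Dates

# Mirrors the `cast` dispatch chain: strip `Missing` from the declared type,
# then pick a conversion by the remaining type.
toy_cast(str, ::Type{Union{Missing, T}}) where {T} = toy_cast(str, T)
toy_cast(str, ::Type{T}) where {T<:Number} = parse(T, str)
toy_cast(str, ::Type{<:AbstractString}) = str
toy_cast(str, ::Type{Date}) = Date(str)   # assumes ISO "yyyy-mm-dd" text

# A row is modeled as a vector of strings, with `nothing` standing in for NULL.
function toy_getvalue(row::Vector{Union{Nothing, String}}, col, ::Type{T}) where {T}
    val = row[col]
    return val === nothing ? missing : toy_cast(val, T)
end

row = Union{Nothing, String}["42", nothing, "2018-09-05", "hello"]
toy_getvalue(row, 1, Int)                      # 42
toy_getvalue(row, 2, Union{Missing, Float64})  # missing
toy_getvalue(row, 3, Date)                     # 2018-09-05 as a Date
toy_getvalue(row, 4, String)                   # "hello"
```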
So we have `getvalue`, which takes a pointer `ptr`, a column index `col`, and a type `T`. It first dereferences the pointer according to our column index; if the dereferenced pointer is `NULL`, it's easy, we return `missing` (I guess we can assume that a non-nullable column would never return a null pointer in this case). In the non-null case, we get the value string via `unsafe_string(deref)`, and then call `cast(str, T)` to convert the string into the appropriate type. Ok, a little C-ish, but not too bad so far. Now let's do some fancy Julia code:
```julia
function Base.iterate(q::Query{hasresult, names, types}, st=1) where {hasresult, names, types}
    st > length(q) && return nothing
    !hasresult && return (num_rows_affected=Int(q.result.ptr),), 2
    nt = generate_namedtuple(NamedTuple{names, types}, q)
    q.ptr = API.mysql_fetch_row(q.result.ptr)
    return nt, st + 1
end
```
Here we define `iterate` over our `Query` object, with the state simply being the row number. If the row number `st` is ever larger than `length(q)`, we just return `nothing`, indicating there's nothing left to iterate. Reading over the `Query` source code, it also looks like when a SQL statement doesn't return a result set (for example an `INSERT` or `UPDATE`, or a DDL statement like `CREATE TABLE`, all of which alter state in the database but don't return any data), it still returns a single row with a single column: `(num_rows_affected=value,)`. This case is flagged by the `hasresult` type parameter, so we check that, return the number of affected rows, and set the next state to 2 to ensure that the next iteration will return `nothing`.
Ok, so now that we know there's an actual row to be iterated, we call `nt = generate_namedtuple(NamedTuple{names, types}, q)` to get our row; we'll go over that function definition in just a bit. Once we have our NamedTuple, we want to make sure the next iteration has the next row ready, so we call `q.ptr = API.mysql_fetch_row(q.result.ptr)`, which fetches the next row of values from the database. Finally, we `return nt, st + 1` from our `iterate` call, which is the row of values (as a NamedTuple) and the next "state", or in this case, row number.
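To make the iteration protocol concrete, here's a sketch with a hypothetical `RowCursor` type. A `Vector` of tuples stands in for the server, advancing `pos` plays the role of `API.mysql_fetch_row`, and the `hasresult == false` branch yields the single `(num_rows_affected = n,)` row described above. None of these names are MySQL.jl API.

```julia
# Hypothetical cursor (not MySQL.jl code). `rows` stands in for the server,
# `pos` for the row currently held in `q.ptr`, and `affected` for the count
# reported when the statement produced no result set.
mutable struct RowCursor{hasresult, names, types}
    rows::Vector{Tuple}
    pos::Int
    affected::Int
end

result_cursor(names, types, rows) = RowCursor{true, names, types}(rows, 1, 0)
affected_cursor(n) = RowCursor{false, (), Tuple{}}(Tuple[], 1, n)

Base.length(c::RowCursor{true}) = length(c.rows)
Base.length(c::RowCursor{false}) = 1   # the single (num_rows_affected = n,) row

function Base.iterate(c::RowCursor{hasresult, names, types}, st = 1) where {hasresult, names, types}
    st > length(c) && return nothing
    # No result set: yield one row; state 2 makes the next call return `nothing`.
    !hasresult && return (num_rows_affected = c.affected,), 2
    nt = NamedTuple{names, types}(c.rows[c.pos])   # build the current row
    c.pos += 1                                      # "fetch" the next row
    return nt, st + 1
end

for row in result_cursor((:id, :score), Tuple{Int, Float64}, Tuple[(1, 0.5), (2, 1.5)])
    println(row)   # (id = 1, score = 0.5), then (id = 2, score = 1.5)
end
collect(affected_cursor(3))   # one row: (num_rows_affected = 3,)
```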
Let's rewind just a bit to that `generate_namedtuple` call and take a deeper look. If you're not familiar with (or don't care about) how generated functions work in Julia, feel free to move on; it's nothing too crazy, and it's what allows our `iterate` function to be type stable. Ok, here's the generated function definition:
```julia
function generate_namedtuple(::Type{NamedTuple{names, types}}, q) where {names, types}
    if @generated
        vals = Tuple(:(getvalue(q.ptr, $i, $(fieldtype(types, i)))) for i = 1:fieldcount(types))
        return :(NamedTuple{names, types}(($(vals...),)))
    else
        return NamedTuple{names, types}(Tuple(getvalue(q.ptr, i, fieldtype(types, i)) for i = 1:fieldcount(types)))
    end
end
```
Ok, let's take this apart, piece by piece. We're passing in a NamedTuple type as the first argument, with type parameters `names` and `types`, which are "known" at compile time, allowing us to generate some code expressions based on those parameters. The `if @generated` block signals to the compiler that we want to take the types of the function arguments and return an `Expr` object that will end up being compiled as the final method body. The `else` block provides a fallback: the compiler is free to use it instead of the generated body (for example, when running in the interpreter or in an environment where compiling a new method body isn't possible), so there's always a code path that doesn't rely on dynamically generating code. In our case, `return NamedTuple{names, types}(Tuple(getvalue(q.ptr, i, fieldtype(types, i)) for i = 1:fieldcount(types)))` is this fallback, but it's less than ideal because the compiler can't infer the tuple built by that generator. In the `@generated` block, however, we can dynamically generate the inner "tuple" of our NamedTuple by first generating `vals = Tuple(:(getvalue(q.ptr, $i, $(fieldtype(types, i)))) for i = 1:fieldcount(types))`, which is a tuple of `Expr` objects. Then we splice those expressions into a tuple literal that's passed to the NamedTuple constructor. The resulting expression would look something like:
```julia
NamedTuple{(:a, :b), Tuple{Int, Int}}((getvalue(q.ptr, 1, Int), getvalue(q.ptr, 2, Int)))
```
This has a statically inferrable return type, which is important for our `iterate` function, since it may be called thousands of times while returning a result set.
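Here's a self-contained sketch of the same optionally-generated pattern, with a hypothetical `toy_namedtuple` that parses one string per field instead of reading from a MySQL row pointer. To keep the sketch conservative, both branches build their values with a plain loop rather than a generator, and the generated branch splices the concrete `NamedTuple` type in directly. `Base.return_types` (a reflection helper) is used only to peek at what inference concludes.

```julia
# Hypothetical per-field parser, standing in for `getvalue`.
toy_parse(::Type{T}, s::AbstractString) where {T<:Number} = parse(T, s)
toy_parse(::Type{String}, s::AbstractString) = String(s)

function toy_namedtuple(::Type{NamedTuple{names, types}}, strs) where {names, types}
    if @generated
        # Runs at compile time: build one `toy_parse` call per field, with the
        # field type and index spliced in as constants.
        vals = Expr[]
        for i = 1:fieldcount(types)
            push!(vals, :(toy_parse($(fieldtype(types, i)), strs[$i])))
        end
        return :($(NamedTuple{names, types})(($(vals...),)))
    else
        # Fallback if the compiler doesn't run the generator:
        # same values, computed with a runtime loop.
        vals = Any[]
        for i = 1:fieldcount(types)
            push!(vals, toy_parse(fieldtype(types, i), strs[i]))
        end
        return NamedTuple{names, types}(Tuple(vals))
    end
end

const NT = NamedTuple{(:id, :score, :name), Tuple{Int, Float64, String}}
row = toy_namedtuple(NT, ["7", "2.5", "x"])   # (id = 7, score = 2.5, name = "x")
Base.return_types(toy_namedtuple, (Type{NT}, Vector{String}))
```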
Now our final remaining task is to enable "sink" support for a MySQL database by allowing any input that implements the Tables.jl interface to be loaded into the database. The MySQL.jl package already has a `MySQL.Stmt` type to facilitate making prepared statements that get compiled once and can be reused over and over by re-binding different parameters, so that seems like the best place to start:
```julia
function execute!(itr, stmt::Stmt)
    rows = Tables.rows(itr)
    state = iterate(rows)
    state === nothing && return stmt
    row, st = state
    sch = Tables.Schema(propertynames(row), nothing)
    binds = Vector{API.MYSQL_BIND}(undef, stmt.nparams)
    bindptr = pointer(binds)
    while true
        Tables.eachcolumn(sch, row) do val, col, nm
            binds[col] = bind(val)
        end
        API.mysql_stmt_bind_param(stmt.ptr, bindptr) == 0 || throw(MySQLStatementError(stmt.ptr))
        API.mysql_stmt_execute(stmt.ptr) == 0 || throw(MySQLStatementError(stmt.ptr))
        stmt.rows_affected += API.mysql_stmt_affected_rows(stmt.ptr)
        state = iterate(rows, st)
        state === nothing && break
        row, st = state
    end
    return stmt
end
```
Let's step through this function bit by bit. First we call `Tables.rows(itr)` to get a row iterator; databases naturally support inserting a row at a time via SQL statements like `INSERT INTO tablename (column names...) VALUES (values...)`, so iterating over rows is natural. One thing we always want to consider in a sink is how to handle an input table with an "unknown schema", which is the case when `Tables.schema(Tables.rows(x))` returns `nothing`. In those cases, we can't rely on knowing the column names or types upfront, before we start iterating. For the MySQL case, we're not especially interested in knowing them anyway: the user passes in `stmt::MySQL.Stmt`, in which they've already compiled their `INSERT INTO ...` SQL statement, including the column names and the number of parameters. So we don't even call `Tables.schema(rows)`; we just begin iterating `rows`, taking the column names from the first row with `propertynames(row)`. Binding values to SQL statement parameters is handled by the `bind` function, with the results stored in the `binds` buffer. The construct
```julia
Tables.eachcolumn(sch, row) do val, col, nm
    binds[col] = bind(val)
end
```
utilizes a helpful function from Tables.jl: `eachcolumn`. `eachcolumn` can take either a full `Tables.Schema` argument or a half-schema like `Tables.Schema(names, nothing)`, and it will essentially "unroll" the inner loop over columns. For each column in the row, it gets the value and passes it, along with the column's index and name, to our `do` block. This makes it extremely convenient to bind each row value as a parameter in our `INSERT` statement. Once the values are bound, we call down to the API C functions to send the values off to the database. After that, we basically just iterate again and check whether we're done yet.
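Finally, a rough sketch of the sink loop's shape, with a hypothetical `ToyStmt` that records each bound parameter row instead of calling the MySQL C API, and a hypothetical `toy_bind` standing in for MySQL.jl's `bind`. The inner `for` loop over `propertynames(row)` does the job `Tables.eachcolumn` does for us (the real function may unroll it), and a `Vector` of NamedTuples is used directly as the row iterator where the real code calls `Tables.rows(itr)`.

```julia
# Hypothetical prepared statement: it just records each bound parameter row.
mutable struct ToyStmt
    nparams::Int
    executed::Vector{Vector{Any}}
    rows_affected::Int
end
ToyStmt(nparams) = ToyStmt(nparams, Vector{Any}[], 0)

# Stand-in for MySQL.jl's `bind`: a bound value is the value itself,
# with `missing` mapped to `nothing` to mimic a SQL NULL.
toy_bind(val) = val
toy_bind(::Missing) = nothing

function toy_execute!(rows, stmt::ToyStmt)
    state = iterate(rows)
    state === nothing && return stmt
    row, st = state
    colnames = propertynames(row)   # "half schema": names only, no types
    binds = Vector{Any}(undef, stmt.nparams)
    while true
        # What Tables.eachcolumn does: visit each column's value, index, and name.
        for (col, nm) in enumerate(colnames)
            binds[col] = toy_bind(getproperty(row, nm))
        end
        push!(stmt.executed, copy(binds))   # "execute" with the bound values
        stmt.rows_affected += 1             # one row per execution here
        state = iterate(rows, st)
        state === nothing && break
        row, st = state
    end
    return stmt
end

stmt = toy_execute!([(id = 1, name = "a"), (id = 2, name = missing)], ToyStmt(2))
stmt.executed       # [[1, "a"], [2, nothing]]
stmt.rows_affected  # 2
```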