
MySQL.jl

Jacob Quinn edited this page Sep 6, 2018 · 2 revisions

Package: MySQL.jl

Implementor: quinnj

Date: 2018-09-05

GitHub Pull Request: https://github.com/JuliaDatabases/MySQL.jl/pull/117

Making a MySQL.Query resultset a table

Ok, so looking over the package, it looks like it has a high-level function MySQL.query(conn, sql) that takes a valid MySQL.Connection conn and a SQL string sql, executes the statement, and then uses the DataStreams package to stream results out to any sink.

There's an internal type MySQL.Query whose constructor actually holds the logic for preparing a statement, executing it against the database, and doing some initial diligence on the layout of the resultset (column names, types), if any.

So to start off, we're going to define:

Tables.istable(::Type{<:Query}) = true
Tables.rowaccess(::Type{<:Query}) = true
Tables.rows(q::Query) = q

These definitions affirm that a MySQL.Query object is a Tables.Table and that it will iterate rows. Note I defined istable and rowaccess on the type <:Query; this is because the type definition is:

mutable struct Query{hasresult, names, T}
    result::Result
    ptr::Ptr{Ptr{Int8}}
    ncols::Int
    nrows::Int
end

Meaning that MySQL.Query on its own is not a concrete type: it's a parametric (UnionAll) type that covers a whole family of types, with Query{true, (:col1, :col2), Tuple{Int, Float64}} being an example of a concrete member of that family (see the manual section on parametric types for more info there). Put simply, I want to make sure that any concrete Query type is recognized as a Table w/ rowaccess, regardless of the combination of type parameters any one instance may have.
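
To make the parametric-type point concrete, here is a minimal sketch. ToyQuery and demo_istable are made-up stand-ins (they are not part of MySQL.jl or Tables.jl); the only thing being illustrated is how one method on ::Type{<:ToyQuery} covers every combination of type parameters.

```julia
# ToyQuery is a hypothetical stand-in for MySQL.Query with the same parameter layout.
struct ToyQuery{hasresult, names, T}
    nrows::Int
end

# Defined once for the whole family of types, in the spirit of Tables.istable above.
# demo_istable is an illustrative name, not a Tables.jl function.
demo_istable(::Type{<:ToyQuery}) = true

# One concrete member of the family.
const IntFloatQuery = ToyQuery{true, (:col1, :col2), Tuple{Int, Float64}}

println(ToyQuery isa UnionAll)            # the bare name is a UnionAll
println(isconcretetype(ToyQuery))         # the bare name is not concrete
println(isconcretetype(IntFloatQuery))    # a fully parameterized type is concrete
println(demo_istable(IntFloatQuery))
println(demo_istable(ToyQuery{false, (), Tuple{}}))
```

Had the method been written for a single fully parameterized type instead, every other parameter combination would have been left without a definition.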

For Tables.rows, we just return the MySQL.Query object itself, since it will be simplest to define iteration on it directly.

Ok, next definition:

Tables.schema(q::Query{hasresult, names, T}) where {hasresult, names, T} = Tables.Schema(names, T)

Here I'm defining Tables.schema on the result of Tables.rows, which just so happens to be the original "table" anyway. Luckily for us, the Query object already encodes the column names and types as type parameters, so we can use those directly to call Tables.Schema(names, T) to construct the schema object. Done.

Now all that's left is to actually ensure that the table (MySQL.Query) can iterate rows! Let's start off with:

Base.length(q::Query) = q.ptr == C_NULL ? 0 : q.nrows
Base.eltype(q::Query{hasresult, names, types}) where {hasresult, names, types} = NamedTuple{names, types}

Luckily for us, the Query type stores the number of rows in the result, so we can just access that directly for our length definition; we also check q.ptr == C_NULL in case a user tries to re-iterate a Query. A Query execution is a read-forward-once-only iterator, so making the length 0 in that case protects against iterator misuse. Similarly for eltype, we can again use the encoded names and types to create a NamedTuple{names, types} type object to signal what we'll be iterating.
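
As a small illustration of what a NamedTuple{names, types} type object looks like, the sketch below builds one from a tuple of names and a Tuple type. The column names and types here are invented for the example and don't come from any real query.

```julia
# Hypothetical column names and types, standing in for a Query's type parameters.
names = (:id, :score)
types = Tuple{Int, Float64}

# The row type that eltype would report for such a Query.
RowType = NamedTuple{names, types}

# The type doubles as a constructor that accepts a plain tuple of values.
row = RowType((1, 2.5))

println(row)
println(fieldnames(RowType))
println(fieldtypes(RowType))
```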

Ok, let's get to the main course here: iterate. First, let's familiarize ourselves with the helper functions that are defined for getting an individual column's value in a row:

cast(str, ::Type{Union{Missing, T}}) where {T} = cast(str, T)
cast(str, ::Type{API.Bit}) = API.Bit(isempty(str) ? 0 : UInt64(str[1]))
cast(str, ::Type{T}) where {T<:Number} = parse(T, str)
cast(str, ::Type{Vector{UInt8}}) = Vector{UInt8}(str)
cast(str, ::Type{<:AbstractString}) = str
cast(str, ::Type{Time}) = mysql_time(str)
cast(str, ::Type{Date}) = mysql_date(str)
cast(str, ::Type{DateTime}) = mysql_datetime(str)

function getvalue(ptr, col, ::Type{T}) where {T}
    deref = unsafe_load(ptr, col)
    return deref == C_NULL ? missing : cast(unsafe_string(deref), T)
end

So we have getvalue, which takes a pointer ptr, a column index col, and a type T. It first dereferences the pointer according to our column index; if the dereferenced pointer is NULL, it's easy, we return missing (I guess we can assume that a non-nullable column would never return a null pointer in this case). In the non-null case, we get the value string via unsafe_string(deref), and then call cast(str, T) to convert the string into the appropriate type. Ok, a little C-ish, but not too bad so far. Now let's do some fancy Julia code:
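
The same dispatch pattern can be tried out without a database. The following is only a sketch: toy_cast and toy_getvalue are hypothetical names, `nothing` plays the role of the NULL pointer, and the date format is an assumption about how the text value looks.

```julia
using Dates

# Strip the Missing part of the type, then dispatch on the remaining type.
toy_cast(str, ::Type{Union{Missing, T}}) where {T} = toy_cast(str, T)
toy_cast(str, ::Type{T}) where {T<:Number} = parse(T, str)
toy_cast(str, ::Type{<:AbstractString}) = str
# Assumes dates arrive as text in yyyy-mm-dd form.
toy_cast(str, ::Type{Date}) = Date(str, dateformat"yyyy-mm-dd")

# `nothing` stands in for a NULL pointer: no value means missing.
function toy_getvalue(raw::Union{Nothing, AbstractString}, ::Type{T}) where {T}
    return raw === nothing ? missing : toy_cast(raw, T)
end

println(toy_getvalue("42", Int))
println(toy_getvalue("3.5", Union{Missing, Float64}))
println(toy_getvalue(nothing, Union{Missing, Int}))
println(toy_getvalue("2018-09-05", Date))
```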

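```julia
function Base.iterate(q::Query{hasresult, names, types}, st=1) where {hasresult, names, types}
    # Stop once every row has been produced (length is 0 for a consumed Query).
    st > length(q) && return nothing
    # Statements without a resultset yield a single row holding the affected-row count.
    !hasresult && return (num_rows_affected=Int(q.result.ptr),), 2
    # Build the current row, then advance the pointer to the next row.
    nt = generate_namedtuple(NamedTuple{names, types}, q)
    q.ptr = API.mysql_fetch_row(q.result.ptr)
    return nt, st + 1
end
```
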
Here we define `iterate` over our `Query` object, with the state being simply the row number. If the row number `st` is ever larger than `length(q)`, we just return `nothing`, indicating there's nothing left to iterate. Reading over the `Query` source code, it also looks like in the case that an SQL query doesn't actually return any rows (in SQL world, this is typically a DML statement like `INSERT` or `UPDATE`, or a DDL statement like `CREATE TABLE`, which alters state in the database but doesn't return any data), it still returns a single row w/ a single column: `(num_rows_affected=value,)`. This is flagged by the `hasresult` type parameter, so we'll check that and return the number of affected rows, and set the next state to 2 to ensure that the next iteration will return `nothing`.

Ok, so now that we know there's an actual row to be iterated, we make a little call like `nt = generate_namedtuple(NamedTuple{names, types}, q)` to get our row; we'll go over that function definition in just a bit. Once we have our NamedTuple, we want to make sure the next iteration has the next row ready, so we call `q.ptr = API.mysql_fetch_row(q.result.ptr)`, which will fetch the next row of values from the database. Finally, we can `return nt, st + 1` from our iterate call, which is the row of values (as a NamedTuple) and the next "state", or in this case, row number.
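
The shape of this `iterate` method can be tried out on an in-memory stand-in. The sketch below is not MySQL.jl code: `ToyRows` and `toy_parse` are hypothetical, the "resultset" is just a vector of string cells, and `nothing` stands in for a NULL value.

```julia
# Hypothetical in-memory stand-in for a resultset: each row is a vector of raw cells.
struct ToyRows{names, types}
    raw::Vector{Vector{Union{Nothing, String}}}
end

Base.length(r::ToyRows) = length(r.raw)
Base.eltype(::Type{ToyRows{names, types}}) where {names, types} = NamedTuple{names, types}

# `nothing` plays the role of a NULL cell; otherwise parse into the non-missing type.
toy_parse(s, ::Type{T}) where {T} = s === nothing ? missing : parse(nonmissingtype(T), s)

function Base.iterate(r::ToyRows{names, types}, st=1) where {names, types}
    st > length(r) && return nothing      # nothing left to iterate
    raw = r.raw[st]
    vals = ntuple(i -> toy_parse(raw[i], fieldtype(types, i)), fieldcount(types))
    return NamedTuple{names, types}(vals), st + 1   # (row, next state)
end

rows = ToyRows{(:id, :score), Tuple{Int, Union{Missing, Float64}}}([
    Union{Nothing, String}["1", "2.5"],
    Union{Nothing, String}["2", nothing],
])

result = collect(rows)
println(result)
```

Because `length` and `eltype` are defined alongside `iterate`, generic functions such as `collect` work on the toy type without any further code.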

Let's rewind just a bit to that `generate_namedtuple` call and take a deeper look. If you're not familiar with or don't care about how generated functions work in Julia, feel free to move on, but it's nothing too crazy, and it's what allows our iterate function to be type stable. Ok, here's the generated function definition:
```julia
function generate_namedtuple(::Type{NamedTuple{names, types}}, q) where {names, types}
    if @generated
        # Compile-time branch: build one getvalue call expression per column.
        vals = Tuple(:(getvalue(q.ptr, $i, $(fieldtype(types, i)))) for i = 1:fieldcount(types))
        return :(NamedTuple{names, types}(($(vals...),)))
    else
        # Runtime fallback: same result, but the return type is not inferrable.
        return NamedTuple{names, types}(Tuple(getvalue(q.ptr, i, fieldtype(types, i)) for i = 1:fieldcount(types)))
    end
end
```

Ok, let's take this apart, piece by piece. So we're passing in a NamedTuple type as the first argument, with type parameters names and types, which are "known" at compile time, allowing us to generate some code expressions based on those parameters. The if @generated block signals to the compiler that we want to take the types of the function arguments and return some kind of Expr object that will end up being compiled as the final method body. The else block provides a fallback for the compiler: in case the generated portion fails, or in a runtime where the compiler isn't available, there'd still be a codepath that doesn't rely on dynamically compiling a new method body. In our case, return NamedTuple{names, types}(Tuple(getvalue(q.ptr, i, fieldtype(types, i)) for i = 1:fieldcount(types))) is this fallback, but it is less than ideal because the result type is not inferrable. In the @generated block, however, we can dynamically generate the inner "tuple" of our NamedTuple by first generating vals = Tuple(:(getvalue(q.ptr, $i, $(fieldtype(types, i)))) for i = 1:fieldcount(types)), which is a tuple of Expr objects. Then we splice those expressions into a tuple expression that is passed to the NamedTuple constructor. The resulting expression would look something like:

NamedTuple{(:a, :b), Tuple{Int, Int}}((getvalue(q.ptr, 1, Int), getvalue(q.ptr, 2, Int)))

This expression has a statically inferrable return type, which is important for our iterate function, which may be called thousands of times while returning a result set.
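
For anyone who wants to experiment with the if @generated pattern outside of MySQL.jl, here is a self-contained sketch. build_row and parse_cell are hypothetical names, and a plain vector of strings replaces the row pointer; the structure otherwise mirrors generate_namedtuple.

```julia
# Hypothetical per-cell conversion, standing in for getvalue.
parse_cell(s::AbstractString, ::Type{T}) where {T} = parse(T, s)

function build_row(::Type{NamedTuple{names, types}}, cells) where {names, types}
    if @generated
        # One call expression per column, with the column type spliced in as a literal.
        vals = Tuple(:(parse_cell(cells[$i], $(fieldtype(types, i)))) for i = 1:fieldcount(types))
        return :(NamedTuple{names, types}(($(vals...),)))
    else
        # Fallback used when the generated body can't be compiled.
        return NamedTuple{names, types}(Tuple(parse_cell(cells[i], fieldtype(types, i)) for i = 1:fieldcount(types)))
    end
end

row = build_row(NamedTuple{(:a, :b), Tuple{Int, Float64}}, ["1", "2.5"])
println(row)
println(typeof(row))
```

For this call, the generated branch would produce a body along the lines of NamedTuple{names, types}((parse_cell(cells[1], Int64), parse_cell(cells[2], Float64))), so every column's type is visible to the compiler.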

Loading any table into a mysql table

Now our final remaining task is to enable "sink" support for a mysql database by allowing any Tables.jl interface input to be loaded into the database. The MySQL.jl package already has a MySQL.Stmt type to facilitate making prepared statements that get compiled once and can be re-used over and over by re-binding different parameters, so that seems like the best place to start:

function execute!(itr, stmt::Stmt)
    rows = Tables.rows(itr)
    state = iterate(rows)
    state === nothing && return stmt
    row, st = state
    sch = Tables.Schema(propertynames(row), nothing)
    binds = Vector{API.MYSQL_BIND}(undef, stmt.nparams)
    bindptr = pointer(binds)

    while true
        Tables.eachcolumn(sch, row) do val, col, nm
            binds[col] = bind(val)
        end
        API.mysql_stmt_bind_param(stmt.ptr, bindptr) == 0 || throw(MySQLStatementError(stmt.ptr))
        API.mysql_stmt_execute(stmt.ptr) == 0 || throw(MySQLStatementError(stmt.ptr))
        stmt.rows_affected += API.mysql_stmt_affected_rows(stmt.ptr)
        state = iterate(rows, st)
        state === nothing && break
        row, st = state
    end
    return stmt
end

Let's step through this function bit-by-bit: first we call Tables.rows(itr) to get a Row-iterator; databases naturally support inserting rows at a time via SQL statements like INSERT INTO tablename (column names...) VALUES (values...), so iterating over rows is natural. One thing we always want to consider in a sink is how to handle the case of an input table type having an "unknown schema". This is detected by a call to sch = Tables.schema(Tables.rows(x)) when sch === nothing. In those cases, we can't rely on knowing the column names or types upfront before we start iterating. For the mysql case, we're not extremely interested in knowing them anyway, since as input here, the user is passing stmt::MySQL.Stmt, in which they compiled their INSERT INTO ... SQL statement, and which included the column names and how many parameters to include. So in this case, we don't even make a call to Tables.schema(rows), but rather we just begin iterating rows. The machinery for binding values to SQL statement parameters is accomplished via the bind function, and storing the result in a binds buffer. The construct like

Tables.eachcolumn(sch, row) do val, col, nm
    binds[col] = bind(val)
end

utilizes a helpful function from Tables.jl: eachcolumn. eachcolumn can actually take a full Tables.Schema argument, or a half-schema like Tables.Schema(names, nothing), and it will essentially "unroll" the inner loop operating over columns. It gets the value for each column in a row and passes that value, along with the column's index and name, to the supplied function. This makes it extremely convenient to bind each row value as parameters in our INSERT statement. Once the values are bound, we call down to the API C-functions to send the values off to the database. After that, we basically just iterate again and check to see if we're done yet.
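
The control flow of execute! can be sketched without a database or Tables.jl. In the following, toy_execute! is a hypothetical function: a plain loop over propertynames(row) replaces Tables.eachcolumn, and instead of binding and executing a prepared statement, each row's parameter values are copied into a log vector.

```julia
# Hypothetical sink: records the parameter values that would be bound for each row.
function toy_execute!(log::Vector{Vector{Any}}, itr, nparams::Int)
    state = iterate(itr)
    state === nothing && return log          # empty input: nothing to insert
    row, st = state
    names = propertynames(row)               # column names taken from the first row
    binds = Vector{Any}(undef, nparams)      # reusable buffer, one slot per parameter

    while true
        for (col, nm) in enumerate(names)
            binds[col] = getproperty(row, nm)
        end
        push!(log, copy(binds))              # stands in for bind + execute
        state = iterate(itr, st)
        state === nothing && break
        row, st = state
    end
    return log
end

input = [(a = 1, b = "x"), (a = 2, b = "y")]
log = toy_execute!(Vector{Any}[], input, 2)
println(log)
```

Taking the column names from the first row is what lets the loop work even when the input has no known schema upfront.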
