The use of NamedTuple in Core.kwcall prevents specialization of keyword arguments #54661

Open
MilesCranmer opened this issue Jun 3, 2024 · 17 comments
Labels
keyword arguments f(x; keyword=arguments)

@MilesCranmer
Member

MilesCranmer commented Jun 3, 2024

The Julia docs here say that to force specialization on types, you should use ::Type{T}. However, this does not work for keyword arguments:

julia> f(; t::Type{T}) where {T} = T
f (generic function with 1 method)

julia> Test.@inferred f(t=Float32)
ERROR: return type Type{Float32} does not match inferred return type DataType

This happens because Core.kwcall receives the keyword arguments as a NamedTuple, and the NamedTuple's type parameters encode only DataType rather than Type{T}. For this reason, some of the Julia internals use _stable_typeof, a workaround version of typeof that maps Float32 to Type{Float32} rather than DataType.
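The following is a minimal sketch of the mapping described above. `stable_typeof` is a hypothetical name chosen here to mimic what this comment says Base's internal `_stable_typeof` does. It is not Base's actual definition. `Core.Typeof` is the built-in that performs the same mapping.

```julia
# Hypothetical helper for illustration only: mirrors the described
# behaviour of Base's internal `_stable_typeof`.
stable_typeof(x) = typeof(x)                    # ordinary values: their type
stable_typeof(::Type{T}) where {T} = Type{T}    # type objects: the singleton Type{T}

typeof(Float32)          # DataType (which type was passed is lost)
stable_typeof(Float32)   # Type{Float32} (the passed type is kept)
Core.Typeof(Float32)     # Type{Float32} (built-in with the same mapping)
stable_typeof(1.0)       # Float64 (non-type values behave like typeof)
```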

I would like there to be a similar mechanism for Core.kwcall so that keyword arguments get the same specialization rules as arguments.

For example, one could use Core.Typeof in constructing the NamedTuple passed to Core.kwcall, so that it preserves the specialization of the calling function. You can actually have the NamedTuple retain this information:

julia> @NamedTuple{t::Type{Float32}}((Float32,))
@NamedTuple{t::Type{Float32}}((Float32,))
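Checking the field types shows what is lost and what the explicit form keeps. This sketch uses only Base (`fieldtype` and the `NamedTuple{names,T}` constructor). The variable names are illustrative.

```julia
# What a keyword call site builds today: the field type is computed with typeof.
nt_default = (; t = Float32)
fieldtype(typeof(nt_default), :t)    # DataType

# The same value in a NamedTuple whose field type is spelled out explicitly.
nt_explicit = NamedTuple{(:t,), Tuple{Type{Float32}}}((Float32,))
fieldtype(typeof(nt_explicit), :t)   # Type{Float32}

# Same runtime value; only the type-level information differs.
nt_default.t === nt_explicit.t       # true
```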

x-ref MilesCranmer/DispatchDoctor.jl#27

@thofma
Contributor

thofma commented Jun 4, 2024

In #39866 I was told that everything is fine, but that the tools are lying. (I think this might be a duplicate of #39866.)

@MilesCranmer
Member Author

MilesCranmer commented Jun 4, 2024

It looks like it's a duplicate but let's keep this one open.

@vtjnash I think this is a real thing? It's because of the fact that keywords are passed via Core.kwcall with a NamedTuple:

julia> typeof((; t=Float32))
@NamedTuple{t::DataType}

This means that there is no way to force specialization on keywords using the same tricks possible with args.

The solution is to apply a _stable_typeof for kwcall.

Of course you might find that the compiler will inline everything here (like with the example given by @vtjnash — the constant propagation goes through the kwcall), but the inference on the keyword call itself still fails, even if you force it.
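The asymmetry can also be checked without @inferred by querying inference directly through `Base.return_types`. The sketch below does this once for a positional method and once for the `Core.kwcall` entry point, using the NamedTuple type that a call site would build. It assumes Julia 1.9 or later, where keyword calls lower to `Core.kwcall`. `fpos` and `fkw` are illustrative names.

```julia
# Positional argument: `::Type{T}` forces specialisation, so inference
# sees the concrete Type{Float32} and returns it.
fpos(t::Type{T}) where {T} = T
rt_pos = only(Base.return_types(fpos, (Type{Float32},)))    # Type{Float32}

# Keyword argument: the call goes through Core.kwcall, whose first argument
# is the NamedTuple built at the call site, typed as {t::DataType}.
fkw(; t::Type{T}) where {T} = T
nt_type = typeof((; t = Float32))                           # field type DataType
rt_kw = only(Base.return_types(Core.kwcall, (nt_type, typeof(fkw))))
# rt_kw is not Type{Float32}; this thread reports DataType here.
```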

@MilesCranmer
Member Author

MilesCranmer commented Jun 4, 2024

So if I lower f(; t::Type{T}), it turns into this:

julia> f(; t::Type{T}) where {T} = T
f (generic function with 1 method)

julia> @code_lowered f(t=Float32)
CodeInfo(
1 ─       Core.NewvarNode(:(@_4))
│   %2  = Core.isdefined(@_2, :t)
└──       goto #3 if not %2
2 ─       @_4 = Core.getfield(@_2, :t)
└──       goto #4
3 ─ %6  = Core.UndefKeywordError(:t)
└──       @_4 = Core.throw(%6)
4 ┄ %8  = @_4
│         t = %8
│   %10 = (:t,)
│   %11 = Core.apply_type(Core.NamedTuple, %10)
│   %12 = Base.structdiff(@_2, %11)
│   %13 = Base.pairs(%12)
│   %14 = Base.isempty(%13)
└──       goto #6 if not %14
5 ─       goto #7
6 ─       Base.kwerr(@_2, @_3)
7 ┄ %18 = Main.:(var"#f#1")(t, @_3)
└──       return %18
)

So the line

Core.apply_type(Core.NamedTuple, %10)

Is where the issue stems from. To propagate the user's specialization, I think we might want to have a way to lower to a stable form of this so that specialization of the keywords propagates.

I tried to understand the lowering in Julia. My understanding is that this line:

julia/src/julia-syntax.scm, lines 621 to 622 at commit 13635e1:

`(call (top structdiff) ,kw (curly (core NamedTuple)
(tuple ,@(map quotify keynames))))))))
is the one responsible? Basically, it calls NamedTuple on keynames, which are the keyword names?

So I guess if instead of wrapping with Core.NamedTuple, it did something like:

Core.NamedTuple{keynames,Tuple{map(Core.Typeof,keyvars)...}}(keyvars)

then it should fix this, because Core.Typeof maps Float32 to Type{Float32} rather than DataType. So, for example:

julia> t = Float64
Float64

julia> keynames = (:t,)
(:t,)

julia> NamedTuple{keynames,Tuple{map(Core.Typeof,(t,))...}}((t,))
@NamedTuple{t::Type{Float64}}((Float64,))

This keeps the user's specialization in the NamedTuple and therefore propagates it to the kwcall, which would make the lowered form of f(; t::Type{T}) where {T} = T match the user-specified specialization.
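The same construction generalises to any number of keywords. `stable_namedtuple` below is a hypothetical helper, not an existing Base function. It packages the proposed expression so that the resulting type can be compared with the type that a keyword call site builds today.

```julia
# Hypothetical helper (not part of Base): build a NamedTuple whose field
# types come from Core.Typeof, as proposed for the kwcall lowering.
stable_namedtuple(names::Tuple{Vararg{Symbol}}, vals::Tuple) =
    NamedTuple{names, Tuple{map(Core.Typeof, vals)...}}(vals)

nt = stable_namedtuple((:t, :n), (Float64, 3))
typeof(nt)                       # NamedTuple{(:t, :n), Tuple{Type{Float64}, Int}}

# For comparison, what a keyword call site builds today:
typeof((; t = Float64, n = 3))   # NamedTuple{(:t, :n), Tuple{DataType, Int}}
```

Only type-valued entries change. Ordinary values such as `3` get the same field type either way, because `Core.Typeof` agrees with `typeof` on non-type values.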

Thoughts @vtjnash? Sorry if I'm completely misreading things.

@JeffBezanson
Member

I think this is maybe mostly an artifact of how @inferred and @code_typed work; if you look at a call site inside a function you get:

julia> f(; t::Type{T}) where {T} = T
julia> g() = f(t=Float32)
julia> @code_typed g()
CodeInfo(
1 ─     return Float32
) => Type{Float32}

So in context we are able to infer it.

So the line

Core.apply_type(Core.NamedTuple, %10)

Is where the issue stems from.

I don't think that's correct --- the problem is that at the keyword arg call site we form a NamedTuple{::DataType}, so that's where the type is "hidden", but in practice inference seems to be good enough to get around this.
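One way to inspect the call site itself, rather than the callee, is to lower the call without running it. This sketch assumes Julia 1.9 or later. `fkw2` does not need to be defined, because lowering is a purely syntactic transformation. The exact printed IR varies between versions, so only the presence of `Core.kwcall` is checked.

```julia
# Lower a keyword call at top level; nothing is executed, so fkw2 may be undefined.
lowered = Meta.@lower fkw2(t = Float32)

# The printed IR routes the call through Core.kwcall with a freshly built NamedTuple.
occursin("kwcall", sprint(show, lowered))                       # true

# That NamedTuple's field type is computed from the value, which gives DataType.
typeof((; t = Float32)) == NamedTuple{(:t,), Tuple{DataType}}   # true
```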

@MilesCranmer
Member Author

MilesCranmer commented Jun 7, 2024

I think this is maybe mostly an artifact of how @inferred and @code_typed work; if you look at a call site inside a function you get:

julia> f(; t::Type{T}) where {T} = T
julia> g() = f(t=Float32)
julia> @code_typed g()
CodeInfo(
1 ─     return Float32
) => Type{Float32}

So in context we are able to infer it.

I think this is only because the compiler is inlining f, right? If there were no inlining, I'm not sure how the type information could propagate, given the use of NamedTuple in kwcall (as you also noted).

but in practice inference seems to be good enough to get around this.

The reason for my issue is precisely this: I have a very complex function in SymbolicRegression.jl where loss_type::Type{T} is passed as a keyword. There, inference fails because the compiler did not inline the function (for good reason), so T was not inferred and known within the function scope.

So I am wondering whether ::Type{T} on kwargs could force specialisation the way it does for args, for example with a Core.Typeof version of NamedTuple for use in Core.kwcall?

@JeffBezanson
Member

Could you make a reduced example? Use @noinline if necessary. I'm not sure inlining makes a difference, since we can propagate partially-const structs through inference, but I'm curious to see an example.

If the problem is the NamedTuple formed at the call site, then ::Type{T} on an argument can't change it; we'd have to form more specific NamedTuples at all kwarg call sites. I'm not sure what the impact of that would be. It might be totally reasonable.

@MilesCranmer
Member Author

MilesCranmer commented Jun 7, 2024

The specific function I first ran into this issue with was this one: https://github.com/MilesCranmer/SymbolicRegression.jl/blob/ea03242d099aa189cad3612291bcaf676d77451c/src/Dataset.jl#L98.

If you call Dataset(randn(3, 32), randn(32); loss_type=Float64), inference fails; Test.@inferred confirmed this.

I have since changed loss_type to a positional argument, and this fixed the issue entirely. But I found it a sharp edge that inference worked only with an argument, despite my use of ::Type{T} in keyword form. So it would be nice if this worked.

@MilesCranmer
Member Author

MilesCranmer commented Jun 10, 2024

My general workaround for anyone running into this inference bug is to pass types wrapped in a Val:

julia> f(; t::Val{T}) where {T} = T
f (generic function with 1 method)

julia> Test.@inferred f(t=Val(Float32))
Float32

This does not have the inference issue because Core.kwcall is passed a @NamedTuple{t::Val{T}}, whereas for ::Type{T}, Core.kwcall receives a @NamedTuple{t::DataType}, hence the issue.
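Here is a minimal self-contained version of this workaround, with the NamedTuple type made visible. It uses `Test.@inferred`, as in the comment above. `fval` is an illustrative name.

```julia
using Test

fval(; t::Val{T}) where {T} = T

# The NamedTuple built at the call site now carries the type in its parameter.
typeof((; t = Val(Float32)))        # NamedTuple{(:t,), Tuple{Val{Float32}}}

# Inference through Core.kwcall can therefore recover T, and @inferred does not throw.
@inferred fval(t = Val(Float32))    # Float32
```

The trade-off is that callers must write `Val(Float32)` instead of `Float32`. A positional `::Type{T}` argument avoids that at the cost of changing the signature.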

This of course makes sense, because Float32 is an instance of DataType, so typeof(Float32) is DataType rather than Type{Float32}. But what does not make sense to me is the asymmetry in the specialization of ::Type{T} on arguments vs keyword arguments, so I think this deserves a patch.

If this is not fixable, it would be good to at least document it at https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing.

Basically, Julia will still avoid specializing on types in keyword arguments, even if they are marked with ::Type{T}. The other two special cases, functions and Val, are not affected, because their specific types are preserved in the NamedTuple passed to Core.kwcall:

julia> (; f = +)
(f = +,)

julia> (; f = +) |> typeof
@NamedTuple{f::typeof(+)}

julia> (; v = Val(1))
(v = Val{1}(),)

julia> (; v = Val(1)) |> typeof
@NamedTuple{v::Val{1}}

julia> (; t = Float32)
(t = Float32,)

julia> (; t = Float32) |> typeof
@NamedTuple{t::DataType}

So there is no asymmetry for those other two special cases; it's literally just types.

@aviatesk
Member

aviatesk commented Jun 10, 2024

This issue is indeed due to the implementations of code_typed and @inferred, and I don't think we should extend @NamedTuple (or Julia's type system more broadly) just to solve this problem. Specifically, I think we should revive #29261 so that certain reflection utilities can start inference with the extended lattice information given at the entry point.

@MilesCranmer
Member Author

@aviatesk I’m not sure I understand, why do you say it is only an interactive tools problem? Note that even Base.promote_op fails the inference.

I think this problem can be solved without extending the type system. All that is needed is a tiny change in the lowering of a keyword call. Currently it is Core.apply_type(Core.NamedTuple, %10), which is equivalent to:

NamedTuple{keynames,Tuple{map(typeof,keyvars)...}}(keyvars)

This should be changed to

NamedTuple{keynames,Tuple{map(Core.Typeof,keyvars)...}}(keyvars)

Which will cause the NamedTuple to pick up the ::Type{T} rather than DataType.

Again, this only affects the lowering to Core.kwcall. It doesn't affect the type system itself.

@KristofferC
Member

KristofferC commented Jun 10, 2024

Note that even Base.promote_op fails the inference.

As far as I understand, you check this inference failure with @inferred ("Test.@inferred confirmed this."). But then there have been comments saying:

I think this is maybe mostly an artifact of how @inferred and @code_typed work

This issue is indeed due to the implementations of code_typed and @inferred

So to me (just reading through the conversation) it seems that this has been missed? Or are you using some other way to determine the "inference bug"?

@MilesCranmer
Member Author

Sorry if I misunderstand; but to clarify, I am using Base.promote_op directly and seeing the same issue with the lack of keyword specialisation.

@KristofferC
Member

Could you show an explicit example of that?

@MilesCranmer
Member Author

MilesCranmer commented Jun 10, 2024

Ok, see below.

First, just some context – I first saw this deep in the call stack of SymbolicRegression.jl. Changing to arguments made the inference issue go away. Hence my interest in fixing this.

Also, just to note – obviously if you do:

julia> f(; t::Type{T}) where {T} = T
f (generic function with 1 method)

julia> g(::Type{T}) where {T} = f(; t=T)
g (generic function with 1 method)

julia> Core.kwcall((; t=Float32), f)
Float32

julia> Base.promote_op(Core.kwcall, typeof((; t=Float32)), typeof(f))
DataType

you will get the failed inference. But I'm assuming you want promote_op to receive a fully realised type so the compiler has a chance to do its magic. So I do that below.

Basically, it's really hard to prevent the compiler from doing some amount of inlining in a toy example, such as f(; t::Type{T}) where {T} = T; g(t::Type{T}) where {T} = f(; t), which the compiler will obviously just inline. So, to prevent this and to reveal the specialisation that is actually going on, I use both @noinline and recursion to emulate a real-world case where you wouldn't have inlining.

First, here is what happens with keywords. I generate 8 functions (an arbitrary, reasonably large number; I'm not sure of the exact dividing line) that randomly call each other recursively up to a stack depth of 10. Now the compiler has no chance of inlining!

gs = [gensym("g") for _ in 1:8]

for g in gs
    @eval @noinline function $g(; t::Type{T}, i) where {T}
        k = rand(1:8)
        if i > 10
            return one(T)
        elseif k == 1
            return $(gs[1])(; t, i=i+1)
        elseif k == 2
            return $(gs[2])(; t, i=i+1)
        elseif k == 3
            return $(gs[3])(; t, i=i+1)
        elseif k == 4
            return $(gs[4])(; t, i=i+1)
        elseif k == 5
            return $(gs[5])(; t, i=i+1)
        elseif k == 6
            return $(gs[6])(; t, i=i+1)
        elseif k == 7
            return $(gs[7])(; t, i=i+1)
        else
            return $(gs[8])(; t, i=i+1)
        end
    end
end

@eval @noinline f(t::Type{T}) where {T} = $(gs[1])(; t, i=1)

When I run promote_op, I get:

julia> Base.promote_op(f, Type{Float64})
Any

Now, if I instead try the exact same thing, but with regular arguments:

gs = [gensym("g") for _ in 1:8]
for g in gs
    @eval @noinline function $g(t::Type{T}, i) where {T}
        k = rand(1:8)
        if i > 10
            return one(T)
        elseif k == 1
            return $(gs[1])(t, i+1)
        elseif k == 2
            return $(gs[2])(t, i+1)
        elseif k == 3
            return $(gs[3])(t, i+1)
        elseif k == 4
            return $(gs[4])(t, i+1)
        elseif k == 5
            return $(gs[5])(t, i+1)
        elseif k == 6
            return $(gs[6])(t, i+1)
        elseif k == 7
            return $(gs[7])(t, i+1)
        else
            return $(gs[8])(t, i+1)
        end
    end
end

@eval @noinline f(t::Type{T}) where {T} = $(gs[1])(t, 1)

and run it, I get correct inference:

julia> Base.promote_op(f, Type{Float64})
Float64

So you can see there is some difference in the behavior of specialisation between keyword arguments and regular arguments. At least, I think this shows it? Obviously this is a contrived example, but there's probably a smaller MWE out there one could pull out. For me I encountered it in that big Dataset constructor I mentioned above.

@aviatesk
Member

This is indeed an inference issue related to kwfunc, but it is a different problem from the one originally raised in this issue.
Regarding this inference issue, when a call graph containing recursion includes a kwfunc, it becomes difficult to resolve the recursion, which can result in inference failure. This is an issue that should ideally be resolved on its own, but it should probably be tracked in a separate issue.

@MilesCranmer
Member Author

MilesCranmer commented Jun 10, 2024

This is indeed an inference issue related to kwfunc, but it is a different problem from the one originally raised in this issue. Regarding this inference issue, when a call graph containing recursion includes a kwfunc, it becomes difficult to resolve the recursion, which can result in inference failure. This is an issue that should ideally be resolved on its own, but it should probably be tracked in a separate issue.

Hm, I wonder if this is the cause of the issue I had meant to report? I have a slight feeling that all of the real-world instances where I have seen this symptom appear were from recursive calls.

The stuff I was posting about Core.kwcall and NamedTuple is from my attempt to do detective work on this (and because Test.@inferred – and every other tool at my disposal – flags it), but perhaps that wasn't actually the issue here?

For the record I really have no idea how Core.kwcall(NamedTuple(kws), f) is even supposed to infer correctly with ::Type{T} without aggressive inlining... Can someone please explain why that works? I have been told that @code_warntype lies, etc... But why? Is there some postprocessing that fills in the types? The lowering shows Core.kwcall(NamedTuple(kws), f). I'm just confused about this. (Please spare no technical detail if possible... I'm really just curious :))

@MilesCranmer
Member Author

Hi all,

Just following up on this. I am trying to understand this deeply, not only for my own curiosity, but also so that I can improve DispatchDoctor.jl which helps detect type instability. So if you can help me understand the reason behind this, I will go make a patch to DispatchDoctor to improve its instability detection.

Basically can someone please confirm to me that

julia> f(; t::Type{T}) where {T} = T

is type stable, even if not inlined?

I think I know how to patch this in DispatchDoctor.jl so it can avoid flagging this as an instability – I will use the example I gave above:

julia> t = Float64;

julia> keynames = (:t,);

julia> NamedTuple{keynames,Tuple{map(Core.Typeof,(t,))...}}((t,))
@NamedTuple{t::Type{Float64}}((Float64,))

If I pass this named tuple to Core.kwcall, it should correct the inference for this example.
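Here is a minimal check of that idea, assuming Julia 1.9 or later (where `Core.kwcall` exists). It calls `Core.kwcall` by hand with the Core.Typeof-style NamedTuple and asks inference for the return type. This is a sketch of the experiment, not DispatchDoctor.jl code. `fk` is an illustrative name.

```julia
fk(; t::Type{T}) where {T} = T

# Hand-built NamedTuple whose field type is Type{Float32} rather than DataType.
nt_stable = NamedTuple{(:t,), Tuple{Type{Float32}}}((Float32,))

# Same runtime result as fk(t = Float32).
Core.kwcall(nt_stable, fk)    # Float32

# Inference on the kwcall entry point now sees Type{Float32} for t.
only(Base.return_types(Core.kwcall, (typeof(nt_stable), typeof(fk))))   # Type{Float32}
```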

@nsajko nsajko added the keyword arguments f(x; keyword=arguments) label Oct 20, 2024