Broadcast.broadcast_shape for BlockedUnitRanges fails inference #310

Open
charleskawczynski opened this issue Sep 19, 2023 · 12 comments · May be fixed by #312 or #313

charleskawczynski commented Sep 19, 2023

MWE:

using BlockArrays
shape1 = (BlockArrays._BlockedUnitRange((2,)),);
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
Base.Broadcast._bcs(shape1, shape2)
# @code_warntype Base.Broadcast._bcs1(shape1[1], shape2[1]) # lower level, but internals
@code_warntype Base.Broadcast.broadcast_shape(shape1, shape2)

charleskawczynski changed the title from "Broadcast BCs for BlockedUnitRange objects fail inference" to "Broadcast.broadcast_shape for BlockedUnitRanges fail inference" on Sep 19, 2023
charleskawczynski changed the title from "Broadcast.broadcast_shape for BlockedUnitRanges fail inference" to "Broadcast.broadcast_shape for BlockedUnitRanges fails inference" on Sep 19, 2023

jishnub commented Sep 20, 2023

This comes from

Base.Broadcast.axistype(a::T, b::T) where T<:BlockedUnitRange = length(b) == 1 ? a : combine_blockaxes(a, b)

Perhaps we could remove the special-casing, since the performance impact of the one-term union will be minimal.
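To make the source of the instability concrete, here is a minimal sketch that does not use BlockArrays at all. ToyAxis, toy_combine, axistype_branchy and axistype_uniform are hypothetical stand-ins invented for illustration; the only assumption carried over from the thread is that the non-trivial branch builds its result with union, which returns a Vector even for tuple inputs.

```julia
# Toy stand-in for a blocked axis; ToyAxis and toy_combine are hypothetical names.
struct ToyAxis{CS}
    lasts::CS   # last index of each block, e.g. (2,) or [2, 5]
end
Base.length(a::ToyAxis) = isempty(a.lasts) ? 0 : Int(last(a.lasts))

# Mimics the non-trivial branch: union on tuples returns a Vector.
toy_combine(a::ToyAxis, b::ToyAxis) = ToyAxis(sort!(union(a.lasts, b.lasts)))

# Special-cased version: which branch runs depends on the runtime value length(b),
# and the two branches return different concrete types.
axistype_branchy(a::T, b::T) where {T<:ToyAxis} =
    length(b) == 1 ? a : toy_combine(a, b)

# Version without the special case: every call takes the same path.
axistype_uniform(a::T, b::T) where {T<:ToyAxis} = toy_combine(a, b)

ax1 = ToyAxis((1,))   # length 1, hits the special case
ax2 = ToyAxis((2,))   # length 2, hits the general branch

typeof(axistype_branchy(ax1, ax1))  # tuple-backed
typeof(axistype_branchy(ax2, ax2))  # vector-backed
typeof(axistype_uniform(ax1, ax1))  # vector-backed in both cases
```

Both inputs to axistype_branchy have the same type, yet the outputs differ in type, which is the pattern the compiler cannot resolve to a single concrete return type. The uniform version trades that for an allocation on every call.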

charleskawczynski commented Sep 20, 2023

Yeah, I tried fixing this. I'll open the PR for convenience (#312). First, I thought that the core issue was that union is used in sortedunion, and union does not preserve tuple types:

julia> union((1,), (2,))
2-element Vector{Int64}:
 1
 2

However, even after fixing that in the PR (which borrows some functions from TupleTools.jl and defines union on Tuples to return a Tuple), the result is still type-unstable because the length of the resulting tuple depends on the values in the input tuples:

[Screenshot, 2023-09-20]

So, while this could fix the type preservation, it won't actually fix the type instability. If you have ideas about how to make Broadcast.broadcast_shape type stable / inferrable, that'd be great
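The point about value-dependent tuple lengths can be sketched in a few lines. tuple_union below is a hypothetical helper written for this illustration; it is not the function from the PR or from TupleTools.jl.

```julia
# Hypothetical helper: a union over tuples that returns a Tuple.
# unique works on values, so the output length is only known at run time.
tuple_union(a::Tuple, b::Tuple) = Tuple(unique((a..., b...)))

t1 = tuple_union((1,), (2,))   # two distinct values
t2 = tuple_union((2,), (2,))   # one distinct value

# Same input types (Tuple{Int}, Tuple{Int}), different output types:
typeof(t1)
typeof(t2)

# For comparison, Base's union on tuples returns a Vector:
union((1,), (2,)) isa Vector{Int}
```

Because a tuple's length is part of its type, a function whose output length depends on the input values cannot have a single concrete return type for a given set of input types.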

jishnub commented Sep 21, 2023

I think it's better for this to return a vector, as the non-trivial branch does. Since unique works on values, the length can't be known at compile time. I wonder if we may just remove the if-else here? I haven't checked if this changes the results

charleskawczynski commented Sep 21, 2023

Yeah, that does fix it. I guess it's worth it. The general case of calling combine_blockaxes is either going to be type-unstable in the case of Tuples (leading to inference triggers), or heap-allocating because of union(a, b). Right now it's a mixture, so it both allocates and is type-unstable.

charleskawczynski commented

I'll update the PR

charleskawczynski commented

Ok, the PR is updated.

charleskawczynski commented Sep 21, 2023

Thanks for the tip @jishnub!

charleskawczynski linked a pull request on Sep 21, 2023 that will close this issue

charleskawczynski commented

Actually, I prefer #313 if that's okay 🙂

charleskawczynski commented

So, it seems that length(b) == 1 ? a is needed for upstream tests to pass.

I updated #313, and basically specialized it so that the tests pass, but I have a feeling that I'm introducing inconsistent behavior between blocklasts()::Tuple and blocklasts()::Vector cases.

I think one potential solution would be to put the length of BlockedUnitRange in the type space:

struct BlockedUnitRange{CS,L} <: AbstractUnitRange{Int}
    first::Int
    lasts::CS
    global _BlockedUnitRange(f, cs::CS) where CS = new{CS, len(f, cs)}(f, cs)
end

len(f, l) = isempty(l) ? 0 : Integer(last(l)-f+1)
...
length(a::BlockedUnitRange{CS, L}) where {CS, L} = L

which will make the call to length (in combine_blockaxes) compile-time known, and remove the type instability. However, that may cause other issues. And I'm now seeing that things like DefaultBlockAxis may make this change more intrusive.
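A self-contained sketch of the "length in the type space" idea, using a toy struct rather than the real BlockedUnitRange. LenAxis, toy_len and is_degenerate are hypothetical names chosen for this illustration, and the struct deliberately does not subtype AbstractUnitRange to keep the example short.

```julia
# Toy version of the proposal: the total length L is a type parameter.
struct LenAxis{CS,L}
    first::Int
    lasts::CS
    # Inner constructor computes L from the values (assumption: mirrors len above).
    LenAxis(f::Int, cs::CS) where {CS} = new{CS,toy_len(f, cs)}(f, cs)
end

toy_len(f, l) = isempty(l) ? 0 : Int(last(l) - f + 1)

# length is now read from the type, so it is known at compile time.
Base.length(::LenAxis{CS,L}) where {CS,L} = L

# The length-1 special case becomes a dispatch decision instead of a runtime branch.
is_degenerate(::LenAxis{CS,1}) where {CS} = true
is_degenerate(::LenAxis) = false

a = LenAxis(1, (2,))   # covers 1:2
b = LenAxis(1, (1,))   # covers 1:1

length(a), length(b)
is_degenerate(a), is_degenerate(b)
typeof(a) == typeof(b)   # false: the lengths differ, so the types differ
```

One trade-off this sketch makes visible: the instability does not disappear but moves into the constructor, whose return type now depends on the values passed in. Whether that is acceptable presumably depends on where axes are constructed relative to hot code.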

Thoughts, @jishnub?

jishnub commented Oct 12, 2023

I wonder if a simple workaround might be to convert the return value of axistype to a Vector?

dlfivefifty commented

> I wonder if a simple workaround might be to convert the return value of axistype to a Vector?

Absolutely not: one of the downstream uses is infinite-dimensional block arrays, and even without that we want to support allocation-free axes when the block sizes are a Fill.

Note this really is an issue inherited from Base's broadcasting: it never should have supported degenerate sizes like randn(5,1) .* randn(5,6), and should instead have restricted broadcasting to randn(5) .* randn(5,6). Similarly, it should only support randn(5,6) .* transpose(randn(6)).

We could potentially disallow degenerate broadcasting for blocked arrays. We'd have to see what tests fail downstream.
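For readers less familiar with the distinction, the following sketch shows the three broadcasting forms mentioned above using plain Base arrays. It uses ones instead of randn so the results are deterministic; the variable names are illustrative only.

```julia
A = ones(5, 6)

col_degenerate = ones(5, 1)       # 5×1 matrix: second dimension has size 1
col_vector     = ones(5)          # 1-d vector: no second dimension at all
row            = transpose(ones(6))

# All three broadcast against A to a 5×6 result in Base:
r1 = col_degenerate .* A   # the "degenerate size" case
r2 = col_vector .* A
r3 = A .* row

# For the degenerate case, the axis of length 1 is only recognised by a runtime
# check on its length, since Base.OneTo(1) and Base.OneTo(6) share a type.
shape = Base.Broadcast.broadcast_shape(
    (Base.OneTo(5), Base.OneTo(1)),
    (Base.OneTo(5), Base.OneTo(6)),
)
```

With plain OneTo axes both outcomes of that check have the same type, so inference is unaffected. For blocked axes the two outcomes can differ in type, which is where the instability discussed in this issue comes from.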

charleskawczynski commented

Agreed, the main reason for fixing this issue is to avoid allocations. So ideally we have a solution that is fully inferred and stack allocated.
