diff --git a/NEWS.md b/NEWS.md index f6d51b1f3..b5ad2ff69 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,6 +16,9 @@ ## Bug fixes +* Correctly throw an error if negative number of rows is passed + to `first` or `last` + ([#3402](https://github.com/JuliaData/DataFrames.jl/pull/3402)) * Always use the default thread pool for multithreaded operations, instead of using the interactive thread pool when Julia was started with `-tM,N` with N > 0 diff --git a/src/abstractdataframe/abstractdataframe.jl b/src/abstractdataframe/abstractdataframe.jl index a812365ee..2b99ce462 100644 --- a/src/abstractdataframe/abstractdataframe.jl +++ b/src/abstractdataframe/abstractdataframe.jl @@ -558,14 +558,19 @@ Base.first(df::AbstractDataFrame) = df[1, :] first(df::AbstractDataFrame, n::Integer; view::Bool=false) Get a data frame with the `n` first rows of `df`. +Get all rows if `n` is greater than the number of rows in `df`. +Error if `n` is negative. If `view=false` a freshly allocated `DataFrame` is returned. If `view=true` then a `SubDataFrame` view into `df` is returned. $METADATA_FIXED """ -@inline Base.first(df::AbstractDataFrame, n::Integer; view::Bool=false) = - view ? Base.view(df, 1:min(n ,nrow(df)), :) : df[1:min(n, nrow(df)), :] +@inline function Base.first(df::AbstractDataFrame, n::Integer; view::Bool=false) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + r = min(n, nrow(df)) + return view ? Base.view(df, 1:r, :) : df[1:r, :] +end """ last(df::AbstractDataFrame) @@ -580,14 +585,19 @@ Base.last(df::AbstractDataFrame) = df[nrow(df), :] last(df::AbstractDataFrame, n::Integer; view::Bool=false) Get a data frame with the `n` last rows of `df`. +Get all rows if `n` is greater than the number of rows in `df`. +Error if `n` is negative. If `view=false` a freshly allocated `DataFrame` is returned. If `view=true` then a `SubDataFrame` view into `df` is returned. $METADATA_FIXED """ -@inline Base.last(df::AbstractDataFrame, n::Integer; view::Bool=false) = - view ? Base.view(df, max(1, nrow(df)-n+1):nrow(df), :) : df[max(1, nrow(df)-n+1):nrow(df), :] +@inline function Base.last(df::AbstractDataFrame, n::Integer; view::Bool=false) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + r = max(1, nrow(df) - n + 1) + return view ? Base.view(df, r:nrow(df), :) : df[r:nrow(df), :] +end """ describe(df::AbstractDataFrame; cols=:) diff --git a/test/dataframe.jl b/test/dataframe.jl index fbc2ec0ca..7efa1ca48 100644 --- a/test/dataframe.jl +++ b/test/dataframe.jl @@ -1180,10 +1180,16 @@ end @test_throws BoundsError first(DataFrame(x=[])) @test_throws BoundsError last(DataFrame(x=[])) - @test first(df, 6) == DataFrame(A=1:6) - @test first(df, 1) == DataFrame(A=1) - @test last(df, 6) == DataFrame(A=5:10) - @test last(df, 1) == DataFrame(A=10) + for v in (true, false) + @test first(df, 6, view=v) == DataFrame(A=1:6) + @test first(df, 1, view=v) == DataFrame(A=1) + @test first(df, 0, view=v) == DataFrame(A=Int[]) + @test_throws ArgumentError first(df, -1, view=v) + @test last(df, 6, view=v) == DataFrame(A=5:10) + @test last(df, 1, view=v) == DataFrame(A=10) + @test last(df, 0, view=v) == DataFrame(A=Int[]) + @test_throws ArgumentError last(df, -1, view=v) + end @inferred first(df, 6) @inferred last(df, 6)