From bfb1a76defae8358b4596a8b042b2757956c0cab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Sat, 6 Jan 2024 12:00:35 +0100 Subject: [PATCH] Fix eachrow and eachcol indexing with CartesianIndex --- NEWS.md | 3 +++ src/abstractdataframe/iteration.jl | 3 +++ test/iteration.jl | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/NEWS.md b/NEWS.md index b5ad2ff69..a454a7793 100644 --- a/NEWS.md +++ b/NEWS.md @@ -26,6 +26,9 @@ * Correctly return `Bool[]` in the `nonunique` function applied to a data frame with a pulled column that has zero levels in the pool ([#3393](https://github.com/JuliaData/DataFrames.jl/pull/3393)) +* Correctly index `eachrow` and `eachcol` with `CartesianIndex` + ([#3413](https://github.com/JuliaData/DataFrames.jl/issues/3413)) + # DataFrames.jl v1.6.1 Release Notes diff --git a/src/abstractdataframe/iteration.jl b/src/abstractdataframe/iteration.jl index c81228fb1..6f6abc620 100644 --- a/src/abstractdataframe/iteration.jl +++ b/src/abstractdataframe/iteration.jl @@ -81,6 +81,7 @@ Base.IndexStyle(::Type{<:DataFrameRows}) = Base.IndexLinear() Base.size(itr::DataFrameRows) = (size(parent(itr), 1), ) Base.@propagate_inbounds Base.getindex(itr::DataFrameRows, i::Int) = parent(itr)[i, :] +Base.@propagate_inbounds Base.getindex(itr::DataFrameRows, i::CartesianIndex{1}) = itr[i[1]] Base.@propagate_inbounds Base.getindex(itr::DataFrameRows, idx) = eachrow(@view parent(itr)[idx isa AbstractVector && !(eltype(idx) <: Bool) ? copy(idx) : idx, :]) @@ -263,6 +264,8 @@ Base.iterate(itr::DataFrameColumns, i::Integer=1) = i <= length(itr) ? (itr[i], i + 1) : nothing Base.@propagate_inbounds Base.getindex(itr::DataFrameColumns, idx::ColumnIndex) = parent(itr)[!, idx] +Base.@propagate_inbounds Base.getindex(itr::DataFrameColumns, idx::CartesianIndex{1}) = + itr[idx[1]] Base.@propagate_inbounds Base.getindex(itr::DataFrameColumns, idx::MultiColumnIndex) = eachcol(parent(itr)[!, idx]) Base.:(==)(itr1::DataFrameColumns, itr2::DataFrameColumns) = diff --git a/test/iteration.jl b/test/iteration.jl index 4c1b9d0d1..249677a02 100644 --- a/test/iteration.jl +++ b/test/iteration.jl @@ -15,6 +15,8 @@ using Test, DataFrames @test sprint(summary, eachrow(df)) == "2-element DataFrameRows" @test Base.IndexStyle(eachrow(df)) == IndexLinear() @test eachrow(df)[1] == DataFrameRow(df, 1, :) + @test eachrow(df)[CartesianIndex(1)] == DataFrameRow(df, 1, :) + @test_throws MethodError eachrow(df)[CartesianIndex(1, 1)] @test collect(eachrow(df)) isa Vector{<:DataFrameRow} @test eltype(eachrow(df)) <: DataFrameRow for row in eachrow(df) @@ -35,6 +37,8 @@ using Test, DataFrames @test_throws ArgumentError size(eachcol(df), 2) @test_throws ArgumentError size(eachcol(df), 0) @test eachcol(df)[1] == df[:, 1] + @test eachcol(df)[CartesianIndex(1)] == df[:, 1] + @test_throws MethodError eachcol(df)[CartesianIndex(1, 1)] @test eachcol(df)[:A] === df[!, :A] @test eachcol(df)[All()] == eachcol(df) @test eachcol(df)[Cols(:)] == eachcol(df)