diff --git a/src/generators/gradient_based/probe.jl b/src/generators/gradient_based/probe.jl index f51f663d0..13d8df6fe 100644 --- a/src/generators/gradient_based/probe.jl +++ b/src/generators/gradient_based/probe.jl @@ -31,43 +31,3 @@ function ProbeGenerator(; user_loss = Objectives.losses_catalogue[loss] return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...) end - -""" - invalidation_rate(ce::AbstractCounterfactualExplanation) - -Calculate the invalidation rate of a counterfactual explanation. - -# Arguments -- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the invalidation rate for. -- `kwargs`: Additional keyword arguments to pass to the function. - -# Returns -The invalidation rate of the counterfactual explanation. -""" -function invalidation_rate(ce::AbstractCounterfactualExplanation) - if !hasfield(typeof(ce.convergence), :invalidation_rate) - @warn "Invalidation rate is only defined for InvalidationRateConvergence. Returning 0." - return 0.0 - end - - index_target = findfirst(map(x -> x == ce.target, ce.data.y_levels)) - f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target] - grad = [] - for i in 1:length(ce.s′) - push!( - grad, - Flux.gradient( - () -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[i], - Flux.params(ce.s′), - )[ce.s′], - ) - end - gradᵀ = LinearAlgebra.transpose(grad) - - identity_matrix = LinearAlgebra.Matrix{Float32}(I, length(grad), length(grad)) - denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1] - - normalized_gradient = f_loss / denominator - ϕ = Distributions.cdf(Distributions.Normal(0, 1), normalized_gradient) - return 1 - ϕ -end diff --git a/src/objectives/Objectives.jl b/src/objectives/Objectives.jl index 50d16c0b3..a3238c182 100644 --- a/src/objectives/Objectives.jl +++ b/src/objectives/Objectives.jl @@ -1,7 +1,6 @@ module Objectives using ..CounterfactualExplanations -using ..CounterfactualExplanations.Generators using Flux using Flux.Losses using ChainRulesCore diff --git a/src/objectives/loss_functions.jl b/src/objectives/loss_functions.jl index 9f13b9daf..d0c265aaa 100644 --- a/src/objectives/loss_functions.jl +++ b/src/objectives/loss_functions.jl @@ -45,6 +45,46 @@ function Flux.Losses.mse(ce::AbstractCounterfactualExplanation; kwargs...) return loss end +""" + invalidation_rate(ce::AbstractCounterfactualExplanation) + +Calculate the invalidation rate of a counterfactual explanation. + +# Arguments +- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the invalidation rate for. +- `kwargs`: Additional keyword arguments to pass to the function. + +# Returns +The invalidation rate of the counterfactual explanation. +""" +function invalidation_rate(ce::AbstractCounterfactualExplanation) + if !hasfield(typeof(ce.convergence), :invalidation_rate) + @warn "Invalidation rate is only defined for InvalidationRateConvergence. Returning 0." + return 0.0 + end + + index_target = findfirst(map(x -> x == ce.target, ce.data.y_levels)) + f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target] + grad = [] + for i in 1:length(ce.s′) + push!( + grad, + Flux.gradient( + () -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[i], + Flux.params(ce.s′), + )[ce.s′], + ) + end + gradᵀ = LinearAlgebra.transpose(grad) + + identity_matrix = LinearAlgebra.Matrix{Float32}(I, length(grad), length(grad)) + denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1] + + normalized_gradient = f_loss / denominator + ϕ = Distributions.cdf(Distributions.Normal(0, 1), normalized_gradient) + return 1 - ϕ +end + """ hinge_loss_ir(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)