Skip to content

Commit

Permalink
shoot me aljeblieft
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Dec 7, 2023
1 parent df3e2f2 commit f8b72de
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 41 deletions.
40 changes: 0 additions & 40 deletions src/generators/gradient_based/probe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,43 +31,3 @@ function ProbeGenerator(;
user_loss = Objectives.losses_catalogue[loss]
return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...)
end

"""
invalidation_rate(ce::AbstractCounterfactualExplanation)
Calculate the invalidation rate of a counterfactual explanation.
# Arguments
- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the invalidation rate for.
- `kwargs`: Additional keyword arguments to pass to the function.
# Returns
The invalidation rate of the counterfactual explanation.
"""
function invalidation_rate(ce::AbstractCounterfactualExplanation)
if !hasfield(typeof(ce.convergence), :invalidation_rate)
@warn "Invalidation rate is only defined for InvalidationRateConvergence. Returning 0."
return 0.0
end

index_target = findfirst(map(x -> x == ce.target, ce.data.y_levels))
f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target]
grad = []
for i in 1:length(ce.s′)
push!(
grad,
Flux.gradient(
() -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[i],
Flux.params(ce.s′),
)[ce.s′],
)
end
gradᵀ = LinearAlgebra.transpose(grad)

identity_matrix = LinearAlgebra.Matrix{Float32}(I, length(grad), length(grad))
denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1]

normalized_gradient = f_loss / denominator
ϕ = Distributions.cdf(Distributions.Normal(0, 1), normalized_gradient)
return 1 - ϕ
end
1 change: 0 additions & 1 deletion src/objectives/Objectives.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Objectives

using ..CounterfactualExplanations
using ..CounterfactualExplanations.Generators
using Flux
using Flux.Losses
using ChainRulesCore
Expand Down
40 changes: 40 additions & 0 deletions src/objectives/loss_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,46 @@ function Flux.Losses.mse(ce::AbstractCounterfactualExplanation; kwargs...)
return loss
end

"""
invalidation_rate(ce::AbstractCounterfactualExplanation)
Calculate the invalidation rate of a counterfactual explanation.
# Arguments
- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the invalidation rate for.
- `kwargs`: Additional keyword arguments to pass to the function.
# Returns
The invalidation rate of the counterfactual explanation.
"""
function invalidation_rate(ce::AbstractCounterfactualExplanation)
if !hasfield(typeof(ce.convergence), :invalidation_rate)
@warn "Invalidation rate is only defined for InvalidationRateConvergence. Returning 0."
return 0.0
end

index_target = findfirst(map(x -> x == ce.target, ce.data.y_levels))
f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target]
grad = []
for i in 1:length(ce.s′)
push!(
grad,
Flux.gradient(
() -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[i],
Flux.params(ce.s′),
)[ce.s′],
)
end
gradᵀ = LinearAlgebra.transpose(grad)

identity_matrix = LinearAlgebra.Matrix{Float32}(I, length(grad), length(grad))
denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1]

normalized_gradient = f_loss / denominator
ϕ = Distributions.cdf(Distributions.Normal(0, 1), normalized_gradient)
return 1 - ϕ
end

"""
hinge_loss_ir(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)
Expand Down

0 comments on commit f8b72de

Please sign in to comment.