Skip to content

Commit

Permalink
Add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
benoit-cty committed Sep 12, 2023
1 parent 7f1406e commit 9e05a0d
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions openfisca_survey_manager/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import re
from typing import Dict, List

from openfisca_core import periods
from openfisca_core.simulations import Simulation
Expand Down Expand Up @@ -562,23 +563,42 @@ class SecretViolationError(Exception):
def compute_winners_loosers(
simulation,
baseline_simulation,
variable = None,
variable:str,
filter_by = None,
period = None,
absolute_minimal_detected_variation = 0,
relative_minimal_detected_variation = .01,
observations_threshold = None,
weighted = True,
alternative_weights = None,
absolute_minimal_detected_variation:float = 0,
relative_minimal_detected_variation:float = .01,
observations_threshold:int = None,
weighted:bool = True,
alternative_weights:List = None,
filtering_variable_by_entity = None,
):
) -> Dict[str, int]:
"""
Compute the number of winners and loosers for a given variable
Args:
simulation: The OpenFisca simulation object
baseline_simulation: The OpenFisca simulation to compare
variable: The variable to be compared
filter_by: The variable to be used as a filter
period: The period of the simulation
absolute_minimal_detected_variation: Absolute minimal variation to be detected, in ratio. Ie 0.5 means 5% of variation wont be counted.
relative_minimal_detected_variation: Relative minimal variation to be detected, in ratio.
observations_threshold: Number of observations needed to avoid a statistical secret violation. Defaults to None.
weighted: Whether to use weights
alternative_weights: The weights to be used
filtering_variable_by_entity: The variable to be used as a filter
Returns:
A dictionary
"""
weight_variable_by_entity = simulation.weight_variable_by_entity
entity_key = baseline_simulation.tax_benefit_system.variables[variable].entity.key

# Get the results of the simulation
after = simulation.adaptative_calculate_variable(variable, period = period)
before = baseline_simulation.adaptative_calculate_variable(variable, period = period)

# Filter if needed
if filtering_variable_by_entity is not None:
if filter_by is None:
filter_by = filtering_variable_by_entity.get(entity_key)
Expand All @@ -591,6 +611,7 @@ def compute_winners_loosers(
after = after[filter_dummy].copy()
before = before[filter_dummy].copy()

# Define weights
weight = np.ones(len(after))
if weighted:
if alternative_weights is not None:
Expand All @@ -601,10 +622,9 @@ def compute_winners_loosers(
else:
log.warn('There is no weight variable for entity {} nor alternative weights. Switch to unweighted'.format(entity_key))

# Compute the weigthed number of zeros or non zeros
value_by_simulation = dict(after = after, before = before)

stats_by_simulation = dict()

for simulation_prefix, value in value_by_simulation.items():
stats = dict()
stats["count_zero"] = (
Expand All @@ -617,6 +637,7 @@ def compute_winners_loosers(
stats_by_simulation[simulation_prefix] = stats
del stats

# Compute the number of entity above or below after
after_value = after
before_value = before

Expand All @@ -631,12 +652,14 @@ def compute_winners_loosers(
after_value < -absolute_minimal_detected_variation
)[almost_zero_before * (after_value < 0)]

# Check if there is a secret violation, without weights
if observations_threshold is not None:
not_legit_below = (below_after.sum() < observations_threshold) & (below_after.sum() > 0)
not_legit_above = (above_after.sum() < observations_threshold) & (above_after.sum() > 0)
if not_legit_below | not_legit_above:
raise SecretViolationError("Not enough observations involved")

# Apply weights
above_after_count = (above_after.astype("float64") * weight.astype("float64")).sum()
below_after_count = (below_after.astype("float64") * weight.astype("float64")).sum()
total = sum(weight)
Expand Down

0 comments on commit 9e05a0d

Please sign in to comment.