Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
feat: added flag to shorten results
Browse files Browse the repository at this point in the history
  • Loading branch information
lukarade committed May 10, 2024
1 parent 3d1a71e commit b5fa598
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 644 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def _handle_unknown_call(self, call: Symbol, reason: Reasons) -> None:
# Deal with the case that the call calls a function parameter.
elif isinstance(call, Parameter):
self.call_graph_forest.get_graph(reason.id).reasons.unknown_calls[call.id] = UnknownProto(
symbol=call, origin=reason.function_scope.symbol
symbol=call, origin=reason.function_scope.symbol,
)

else:
self.call_graph_forest.get_graph(reason.id).reasons.unknown_calls[call.id] = UnknownProto(
symbol=call, origin=reason.function_scope.symbol
symbol=call, origin=reason.function_scope.symbol,
)

def _handle_cycles(self, removed_nodes: set[NodeID] | None = None) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,15 @@ def _get_impurity_result(reasons: Reasons) -> PurityResult:
impurity_reasons.add(
UnknownCall(
expression=UnknownFunctionCall(call=unknown_call.symbol.node),
origin=unknown_call.origin
origin=unknown_call.origin,
),
)
# Handle parameter calls
elif isinstance(unknown_call.symbol, Parameter):
impurity_reasons.add(
CallOfParameter(
expression=ParameterAccess(unknown_call.symbol),
origin=unknown_call.origin
origin=unknown_call.origin,
),
)
# Do not handle imported calls here since they are handled separately.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __hash__(self) -> int:
return hash(str(self))

@abstractmethod
def to_dict(self) -> dict[str, Any]:
def to_dict(self, shorten: bool = False) -> dict[str, Any]:
pass

@abstractmethod
Expand Down Expand Up @@ -105,7 +105,7 @@ def update(self, other: PurityResult | None) -> PurityResult:
def clone() -> Pure:
return Pure()

def to_dict(self) -> dict[str, Any]:
def to_dict(self, shorten: bool = False) -> dict[str, Any]: # noqa: ARG002
return {"purity": self.__class__.__name__}

def __hash__(self) -> int:
Expand Down Expand Up @@ -173,7 +173,7 @@ def update(self, other: PurityResult | None) -> PurityResult:
def clone(self) -> Impure:
return Impure(reasons=self.reasons.copy())

def to_dict(self) -> dict[str, Any]:
def to_dict(self, shorten: bool = False) -> dict[str, Any]:
seen = set()
non_local_variable_reads = []
non_local_variable_writes = []
Expand Down Expand Up @@ -202,16 +202,26 @@ def to_dict(self) -> dict[str, Any]:
parameter_calls.append(reason.to_dict())
case _:
raise TypeError(f"Unknown reason type: {reason}")

combined_reasons = {
"NonLocalVariableRead": non_local_variable_reads,
"NonLocalVariableWrite": non_local_variable_writes,
"FileRead": file_reads,
"FileWrite": file_writes,
"UnknownCall": unknown_calls,
"NativeCall": native_calls,
"CallOfParameter": parameter_calls,
}
if not shorten:
combined_reasons = {
"NonLocalVariableRead": non_local_variable_reads,
"NonLocalVariableWrite": non_local_variable_writes,
"FileRead": file_reads,
"FileWrite": file_writes,
"UnknownCall": unknown_calls,
"NativeCall": native_calls,
"CallOfParameter": parameter_calls,
}
else:
combined_reasons = {
"NonLocalVariableRead": len(non_local_variable_reads),
"NonLocalVariableWrite": len(non_local_variable_writes),
"FileRead": len(file_reads),
"FileWrite": len(file_writes),
"UnknownCall": len(unknown_calls),
"NativeCall": len(native_calls),
"CallOfParameter": len(parameter_calls),
}
return {
"purity": self.__class__.__name__,
"reasons": {
Expand Down Expand Up @@ -395,8 +405,9 @@ class UnknownProto(Unknown):
origin : Symbol | NodeID | None
The origin of the unknown call.
"""

symbol: Symbol | Reference
origin: Symbol | NodeID | None = field(default=None)
origin: Symbol | NodeID | None = field(default=None) # TODO: remove NodeID

def __hash__(self) -> int:
return hash(str(self))
Expand Down Expand Up @@ -643,15 +654,15 @@ class APIPurity:

purity_results: typing.ClassVar[dict[NodeID, dict[NodeID, PurityResult]]] = {}

def to_json_file(self, path: Path) -> None:
def to_json_file(self, path: Path, shorten: bool = False) -> None:
ensure_file_exists(path)
with path.open("w") as f:
json.dump(self.to_dict(), f, indent=2)
json.dump(self.to_dict(shorten), f, indent=2)

def to_dict(self) -> dict[str, Any]:
def to_dict(self, shorten: bool = False) -> dict[str, Any]:
return {
module_name.__str__(): {
function_id.__str__(): purity.to_dict()
function_id.__str__(): purity.to_dict(shorten)
for function_id, purity in purity_result.items()
if not purity.is_class
}
Expand Down
Loading

0 comments on commit b5fa598

Please sign in to comment.