# -*- coding: utf-8 -*-
import ast
from typing import Callable, Optional, Sequence, Union, overload
from typing_extensions import Literal
[docs]
class Predicate:
    """A boolean test applied to an AST node or to a raw node id.

    ``condition`` is the wrapped callable. When ``use_raw_node_id`` is true,
    ``ast.AST`` arguments are converted to ``id(node)`` before being handed to
    ``condition``; every other argument is forwarded unchanged. ``static``
    marks a predicate whose ``dynamic_call`` short-circuits to ``True``
    instead of evaluating ``condition``.
    """

    # Shared always-true / always-false instances, assigned after the class body.
    TRUE: "Predicate"
    FALSE: "Predicate"

    # Overload defaults mirror the implementation below (``static`` defaults to
    # False everywhere) so tooling does not advertise a default that never applies.
    @overload
    def __init__(
        self,
        condition: Callable[[ast.AST], bool],
        use_raw_node_id: Literal[False],
        static: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        condition: Callable[[int], bool],
        use_raw_node_id: Literal[True],
        static: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        condition: Callable[..., bool],
        use_raw_node_id: bool = False,
        static: bool = False,
    ) -> None: ...

    def __init__(
        self,
        condition: Callable[..., bool],
        use_raw_node_id: bool = False,
        static: bool = False,
    ) -> None:
        """Store the predicate configuration.

        Args:
            condition: callable returning a bool. It receives the argument given
                to ``__call__``, or ``id(node)`` for ``ast.AST`` arguments when
                ``use_raw_node_id`` is true.
            use_raw_node_id: if true, ``__call__`` passes ``id(node)`` instead of
                the ``ast.AST`` node itself to ``condition``.
            static: if true, ``dynamic_call`` returns ``True`` without
                evaluating ``condition``.
        """
        self.condition = condition
        self.use_raw_node_id = use_raw_node_id
        self.static = static

    def clone(self) -> "Predicate":
        """Return a new instance sharing this one's condition and flags.

        The ``condition`` callable itself is shared, not copied.
        """
        return self.__class__(self.condition, self.use_raw_node_id, self.static)

    def __call__(self, node: Union[ast.AST, int]) -> bool:
        """Evaluate ``condition`` on ``node`` (or on ``id(node)``, see class doc)."""
        # Only real AST nodes are converted; an int is assumed to already be an id.
        node_or_id = (
            id(node) if self.use_raw_node_id and isinstance(node, ast.AST) else node
        )
        return self.condition(node_or_id)  # type: ignore

    def dynamic_call(self, node: Union[ast.AST, int]) -> bool:
        """Return ``True`` for static predicates, else evaluate ``self(node)``."""
        return True if self.static else self(node)
# Shared constant predicates: TRUE accepts any argument, FALSE rejects any
# argument. Both are ordinary non-static predicates, so ``dynamic_call`` on
# them still invokes the lambda.
Predicate.TRUE, Predicate.FALSE = (
    Predicate(lambda *_args: True),
    Predicate(lambda *_args: False),
)
[docs]
class CompositePredicate(Predicate):
    """A predicate that combines several base predicates with a reducer.

    ``reducer`` is applied to a lazy generator of the base predicates' results
    (the ``any`` / ``all`` factories pass the builtins ``any`` / ``all``). The
    composite is static only when every base predicate is static, and uses raw
    node ids only when every base predicate does.
    """

    # NOTE: the ``reducer=any`` default is evaluated while the class body runs,
    # before the ``any`` classmethod below exists, so it is the builtin ``any``.
    # Likewise ``any`` / ``all`` inside method bodies resolve to the builtins,
    # because class scope is not visible from within methods.
    def __init__(self, base_predicates: Sequence[Predicate], reducer=any) -> None:
        """Store the base predicates and derive ``static`` / ``use_raw_node_id``.

        Args:
            base_predicates: predicates to combine; copied into a new list.
            reducer: callable reducing an iterable of bools to a bool.
        """
        # Derive everything from the materialized copy so a one-shot iterable
        # is not exhausted before the derived attributes are computed.
        self.base_predicates = list(base_predicates)
        self.dynamic_base_predicates = [
            pred for pred in self.base_predicates if not pred.static
        ]
        self.static = len(self.dynamic_base_predicates) == 0
        self.use_raw_node_id = all(
            pred.use_raw_node_id for pred in self.base_predicates
        )
        self.reducer = reducer

    def clone(self) -> "CompositePredicate":
        """Return a new composite with the same base predicates and reducer.

        Overrides ``Predicate.clone``, which reads ``self.condition`` (never set
        on composites) and calls the constructor with ``Predicate.__init__``'s
        argument shape. As in ``Predicate.clone``, the current ``static`` and
        ``use_raw_node_id`` values are carried over. The base predicates
        themselves are shared, not cloned.
        """
        cloned = self.__class__(self.base_predicates, reducer=self.reducer)
        cloned.static = self.static
        cloned.use_raw_node_id = self.use_raw_node_id
        return cloned

    def __call__(
        self,
        node: Union[ast.AST, int],
        predicates: Optional[Sequence[Predicate]] = None,
    ) -> bool:
        """Reduce the results of ``predicates`` (default: all base predicates).

        Each predicate performs its own node -> id conversion if configured.

        Raises:
            AssertionError: if the selected predicate sequence is empty.
        """
        predicates = self.base_predicates if predicates is None else predicates
        assert len(predicates) > 0
        return self.reducer(pred(node) for pred in predicates)

    def dynamic_call(self, node: Union[ast.AST, int]) -> bool:
        """Return ``True`` if static, else reduce over only the non-static bases."""
        return (
            True if self.static else self(node, predicates=self.dynamic_base_predicates)
        )

    @classmethod
    def _create(cls, base_predicates: Sequence[Predicate], reducer) -> Predicate:
        """Construct a composite; ``base_predicates`` must be non-empty."""
        assert len(base_predicates) > 0
        return cls(base_predicates, reducer=reducer)

    @classmethod
    def any(cls, base_predicates: Sequence[Predicate]) -> Predicate:
        """Return a predicate that is true when any base predicate is true.

        Returns ``Predicate.TRUE`` for an empty sequence or one containing
        ``Predicate.TRUE``, and ``Predicate.FALSE`` when every element is
        ``Predicate.FALSE``. Otherwise returns a new composite using ``any``.
        """
        # NOTE(review): an empty sequence yields TRUE (unlike builtin any([])),
        # presumably meaning "no constraint"; confirm with callers.
        if len(base_predicates) == 0 or any(
            pred is Predicate.TRUE for pred in base_predicates
        ):
            return Predicate.TRUE
        if all(pred is Predicate.FALSE for pred in base_predicates):
            return Predicate.FALSE
        return cls._create(base_predicates, reducer=any)

    @classmethod
    def all(cls, base_predicates: Sequence[Predicate]) -> Predicate:
        """Return a predicate that is true when every base predicate is true.

        Returns ``Predicate.FALSE`` if any element is ``Predicate.FALSE``, and
        ``Predicate.TRUE`` when every element is ``Predicate.TRUE``, which
        includes the empty sequence. Otherwise returns a new composite using ``all``.
        """
        if any(pred is Predicate.FALSE for pred in base_predicates):
            return Predicate.FALSE
        if all(pred is Predicate.TRUE for pred in base_predicates):
            return Predicate.TRUE
        return cls._create(base_predicates, reducer=all)