Source code for pyccolo.predicate

# -*- coding: utf-8 -*-
import ast
from typing import Callable, Optional, Sequence, Union, overload

from typing_extensions import Literal


class Predicate:
    """Boolean test applied to an AST node or to a raw node id.

    ``condition`` is the wrapped callable. When ``use_raw_node_id`` is set,
    AST nodes passed to the predicate are replaced by ``id(node)`` before
    ``condition`` runs; any other argument is forwarded untouched. ``static``
    only affects :meth:`dynamic_call`, which then reports ``True`` without
    evaluating ``condition``.
    """

    # Shared singleton predicates; assigned at module level after the class body.
    TRUE: "Predicate"
    FALSE: "Predicate"

    # NOTE(review): this overload advertises ``static=True`` as the default,
    # but the runtime implementation below defaults ``static`` to ``False``.
    @overload
    def __init__(
        self,
        condition: Callable[[ast.AST], bool],
        use_raw_node_id: Literal[False],
        static: bool = True,
    ) -> None: ...

    @overload
    def __init__(
        self,
        condition: Callable[[int], bool],
        use_raw_node_id: Literal[True],
        static: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        condition: Callable[..., bool],
        use_raw_node_id: bool = False,
        static: bool = False,
    ) -> None: ...

    def __init__(
        self,
        condition: Callable[..., bool],
        use_raw_node_id: bool = False,
        static: bool = False,
    ) -> None:
        """Store the condition together with its two behavior flags."""
        self.condition = condition
        self.use_raw_node_id = use_raw_node_id
        self.static = static

    def clone(self) -> "Predicate":
        """Return a new instance of the same class sharing this condition and flags."""
        klass = self.__class__
        return klass(self.condition, self.use_raw_node_id, self.static)

    def __call__(self, node: Union[ast.AST, int]) -> bool:
        """Evaluate the condition on ``node`` (or on ``id(node)``, see class docs)."""
        if self.use_raw_node_id and isinstance(node, ast.AST):
            # The condition expects a raw id, but we were handed the node itself.
            return self.condition(id(node))
        return self.condition(node)  # type: ignore

    def dynamic_call(self, node: Union[ast.AST, int]) -> bool:
        """Return ``True`` outright for static predicates, else evaluate on ``node``."""
        if self.static:
            return True
        return self(node)
# Constant predicates that ignore their argument. They are created exactly once
# so that other code can recognise them by identity (``pred is Predicate.TRUE``).
Predicate.TRUE = Predicate(lambda *_ignored: True)
Predicate.FALSE = Predicate(lambda *_ignored: False)
class CompositePredicate(Predicate):
    """Predicate that folds several base predicates together with a reducer.

    ``reducer`` is called with an iterable of the base predicates' boolean
    results (the builtin ``any`` by default). The composite counts as static
    only when every base predicate is static, and ``use_raw_node_id`` is set
    only when every base predicate sets it. Each base predicate still performs
    its own node-to-id conversion when it is called.
    """

    # NOTE: ``any`` in the default below is the builtin. The signature is
    # evaluated before the ``any`` classmethod further down is bound in the
    # class namespace, so ``__init__`` must stay above that classmethod.
    def __init__(self, base_predicates: Sequence[Predicate], reducer=any) -> None:
        """Record the base predicates and derive the composite's flags.

        Args:
            base_predicates: predicates to combine; copied into a new list.
            reducer: callable folding an iterable of bools into one bool.
        """
        # Materialise once and derive everything from the copy, so that a
        # single-pass iterable is not exhausted before the flags are computed.
        self.base_predicates = list(base_predicates)
        self.dynamic_base_predicates = [
            pred for pred in self.base_predicates if not pred.static
        ]
        self.static = len(self.dynamic_base_predicates) == 0
        self.use_raw_node_id = all(
            pred.use_raw_node_id for pred in self.base_predicates
        )
        self.reducer = reducer

    def clone(self) -> "CompositePredicate":
        """Return a new composite over the same base predicates and reducer.

        The inherited ``Predicate.clone`` reads ``self.condition``, which a
        composite never sets, so it would raise ``AttributeError`` here.
        """
        return self.__class__(self.base_predicates, reducer=self.reducer)

    def __call__(
        self,
        node: Union[ast.AST, int],
        predicates: Optional[Sequence[Predicate]] = None,
    ) -> bool:
        """Reduce the results of ``predicates`` (default: all base predicates) on ``node``.

        Raises:
            AssertionError: if the predicate sequence being evaluated is empty.
        """
        predicates = self.base_predicates if predicates is None else predicates
        assert len(predicates) > 0
        return self.reducer(pred(node) for pred in predicates)

    def dynamic_call(self, node: Union[ast.AST, int]) -> bool:
        """Return ``True`` if fully static, else reduce over the non-static predicates only."""
        if self.static:
            return True
        return self(node, predicates=self.dynamic_base_predicates)

    @classmethod
    def _create(cls, base_predicates: Sequence[Predicate], reducer) -> Predicate:
        """Build a composite from a non-empty sequence of predicates."""
        assert len(base_predicates) > 0
        return cls(base_predicates, reducer=reducer)

    @classmethod
    def any(cls, base_predicates: Sequence[Predicate]) -> Predicate:
        """Return a predicate that holds when at least one base predicate holds.

        Short-circuits to ``Predicate.TRUE`` when the sequence is empty or
        contains ``Predicate.TRUE``, and to ``Predicate.FALSE`` when every
        entry is ``Predicate.FALSE``.
        """
        # Inside a method body, ``any`` / ``all`` resolve to the builtins,
        # not to the classmethods of the same name.
        if len(base_predicates) == 0 or any(
            pred is Predicate.TRUE for pred in base_predicates
        ):
            return Predicate.TRUE
        if all(pred is Predicate.FALSE for pred in base_predicates):
            return Predicate.FALSE
        return cls._create(base_predicates, reducer=any)

    @classmethod
    def all(cls, base_predicates: Sequence[Predicate]) -> Predicate:
        """Return a predicate that holds when every base predicate holds.

        Short-circuits to ``Predicate.FALSE`` when any entry is
        ``Predicate.FALSE``, and to ``Predicate.TRUE`` when every entry is
        ``Predicate.TRUE`` (which includes the empty sequence).
        """
        if any(pred is Predicate.FALSE for pred in base_predicates):
            return Predicate.FALSE
        if all(pred is Predicate.TRUE for pred in base_predicates):
            return Predicate.TRUE
        return cls._create(base_predicates, reducer=all)