# -*- coding: utf-8 -*-
import ast
import bisect
import logging
from collections import defaultdict
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Callable,
DefaultDict,
Dict,
Generator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from pyccolo.ast_bookkeeping import AstBookkeeper, BookkeepingVisitor
from pyccolo.expr_rewriter import ExprRewriter
from pyccolo.handler import HandlerSpec
from pyccolo.predicate import CompositePredicate, Predicate
from pyccolo.stmt_inserter import StatementInserter
from pyccolo.stmt_mapper import StatementMapper
from pyccolo.syntax_augmentation import (
AugmentationSpec,
AugmentationType,
Position,
Range,
fix_positions,
)
from pyccolo.trace_events import TraceEvent
if TYPE_CHECKING:
from pyccolo.tracer import BaseTracer
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
_T = TypeVar("_T")
GUARD_DATA_T = Tuple[HandlerSpec, Callable[[Union[int, ast.AST]], str]]
[docs]
# Node transformer that, on behalf of a list of tracers, copies the tree it is
# given, records bookkeeping / syntax-augmentation info on the last tracer, and
# (optionally) weaves in tracing instrumentation -- see ``visit``.
class AstRewriter(ast.NodeTransformer):
# When true, ``visit`` removes the bookkeeper previously registered for the
# same path from the last tracer before the freshly built one is added.
gc_bookkeeping = True
def __init__(
self,
tracers: "List[BaseTracer]",
path: str,
module_id: Optional[int] = None,
) -> None:
self._tracers = tracers
self._path = path
self._module_id = module_id
self._augmented_positions_by_spec: Dict[AugmentationSpec, Set[Position]] = (
defaultdict(set)
)
self.orig_to_copy_mapping: Optional[Dict[int, ast.AST]] = None
@contextmanager
def tracer_override_context(
self, tracers: List["BaseTracer"], path: str
) -> Generator[None, None, None]:
orig_tracers = self._tracers
orig_path = self._path
self._tracers = tracers
self._path = path
try:
yield
finally:
self._tracers = orig_tracers
self._path = orig_path
def _get_order_of_specs_applied(self) -> Tuple[AugmentationSpec, ...]:
specs = []
for tracer in self._tracers:
for spec in tracer.last_applied_specs:
if spec not in specs:
specs.append(spec)
return tuple(specs)
def register_augmented_position(
    self, aug_spec: AugmentationSpec, lineno: int, col_offset: int
) -> None:
    """Record that ``aug_spec`` was applied at ``(lineno, col_offset)``.

    The position is added to the per-spec set that
    ``_handle_all_augmentations`` later consumes.
    """
    positions_for_spec = self._augmented_positions_by_spec[aug_spec]
    positions_for_spec.add(Position(lineno, col_offset))
def _make_node_copy_flyweight(
self, predicate: Callable[..., _T]
) -> Callable[..., _T]:
return lambda node_or_id: predicate(
(self.orig_to_copy_mapping or {}).get(
node_or_id if isinstance(node_or_id, int) else id(node_or_id),
node_or_id,
)
)
def should_instrument_with_tracer(self, tracer: "BaseTracer") -> bool:
return self._path is None or tracer._should_instrument_file_impl(self._path)
@staticmethod
def _get_prefix_range_for(
    node: ast.AST,
) -> Optional[Range]:
    """Compute the prefix anchor position for ``node``.

    Supported nodes and the column that is used:

    * ``ast.Name``: the node's own start column.
    * ``ast.Attribute``: one past the end of the value (i.e. past the dot).
    * function / class definitions: the node's column plus the width of the
      introducing keyword(s) and one space each.
    * ``import`` / ``from ... import`` with exactly one alias: the column
      plus the width of the text preceding the imported name.

    Keyword-derived columns assume exactly one space between tokens.

    Args:
        node: AST node to inspect.

    Returns:
        ``Range.singleton_span(line, col)`` for a supported node, otherwise
        ``None``.
    """
    line, col = None, None
    if isinstance(node, ast.Name):
        line, col = node.lineno, node.col_offset
    elif isinstance(node, ast.Attribute):
        # -2 is the fallback when the value carries no ``end_col_offset``.
        line, col = node.lineno, getattr(node.value, "end_col_offset", -2) + 1
    elif isinstance(node, ast.FunctionDef):
        # TODO: can be different if more spaces between 'def' and function name
        line, col = node.lineno, node.col_offset + 4
    elif isinstance(node, ast.ClassDef):
        # TODO: can be different if more spaces between 'class' and class name
        line, col = node.lineno, node.col_offset + 6
    elif isinstance(node, ast.AsyncFunctionDef):
        # TODO: can be different if more spaces between 'async', 'def', and function name
        line, col = node.lineno, node.col_offset + 10
    elif isinstance(node, (ast.Import, ast.ImportFrom)) and len(node.names) == 1:
        if isinstance(node, ast.Import):
            # len("import ")
            base_offset = 7
        else:
            # len("from ") + leading dots of a relative import + module name
            # + len(" import "). ``level`` is 0 for absolute imports and may
            # be ``None`` on hand-built nodes.
            base_offset = 13 + (node.level or 0) + len(node.module or "")
        name = node.names[0]
        # NOTE(review): for an aliased import only one separator character is
        # added after the original name, although " as " is four characters
        # wide -- confirm whether this is intentional before changing it.
        line, col = node.lineno, (
            node.col_offset
            + base_offset
            + (0 if name.asname is None else len(name.name) + 1)
        )
    if line is None or col is None:
        return None
    return Range.singleton_span(line, col)
@staticmethod
def _get_suffix_range_for(
    node: ast.AST,
) -> Optional[Range]:
    """Compute the suffix anchor position for ``node``.

    Supported nodes and the column that is used:

    * ``ast.Name``: start column plus the length of the identifier.
    * ``ast.Attribute``: end of the value, plus the dot, plus the length of
      the attribute name.
    * function / class definitions: the node's column plus the width of the
      introducing keyword(s), one space each, and the defined name.
    * ``import`` / ``from ... import`` with exactly one alias: the column
      plus the width of the text preceding the imported name, plus the name
      (and alias, if any).

    Keyword-derived columns assume exactly one space between tokens.

    Args:
        node: AST node to inspect.

    Returns:
        ``Range.singleton_span(line, col)`` for a supported node, otherwise
        ``None``.
    """
    line, col = None, None
    if isinstance(node, ast.Name):
        line, col = node.lineno, node.col_offset + len(node.id)
    elif isinstance(node, ast.Attribute):
        # -1 is the fallback when the value carries no ``end_col_offset``.
        line, col = (
            node.lineno,
            getattr(node.value, "end_col_offset", -1) + len(node.attr) + 1,
        )
    elif isinstance(node, ast.FunctionDef):
        # TODO: can be different if more spaces between 'def' and function name
        line, col = node.lineno, node.col_offset + 4 + len(node.name)
    elif isinstance(node, ast.ClassDef):
        # TODO: can be different if more spaces between 'class' and class name
        line, col = node.lineno, node.col_offset + 6 + len(node.name)
    elif isinstance(node, ast.AsyncFunctionDef):
        # TODO: can be different if more spaces between 'async', 'def', and function name
        line, col = node.lineno, node.col_offset + 10 + len(node.name)
    elif isinstance(node, (ast.Import, ast.ImportFrom)) and len(node.names) == 1:
        name = node.names[0]
        if isinstance(node, ast.Import):
            # len("import ")
            base_offset = 7
        else:
            # len("from ") + leading dots of a relative import + module name
            # + len(" import "). ``level`` is 0 for absolute imports and may
            # be ``None`` on hand-built nodes.
            base_offset = 13 + (node.level or 0) + len(node.module or "")
        col_offset = node.col_offset + base_offset + len(name.name)
        if name.asname is not None:
            # NOTE(review): only one separator character is counted between
            # the original name and the alias, although " as " is four
            # characters wide -- confirm whether this is intentional.
            col_offset += 1 + len(name.asname)
        line, col = node.lineno, col_offset
    if line is None or col is None:
        return None
    return Range.singleton_span(line, col)
@staticmethod
def _get_dot_suffix_range_for(
node: ast.AST,
) -> Optional[Range]:
end_lineno, end_col_offset = None, None
if isinstance(node, ast.Name):
end_lineno = getattr(node, "end_lineno", None)
end_col_offset = getattr(node, "end_col_offset", None)
elif isinstance(node, (ast.Attribute, ast.Subscript)):
end_lineno = getattr(node.value, "end_lineno", None)
end_col_offset = getattr(node.value, "end_col_offset", None)
if end_lineno is None or end_col_offset is None:
return None
else:
return Range.singleton_span(end_lineno, end_col_offset)
@staticmethod
def _get_dot_prefix_range_for(
node: ast.AST,
) -> Optional[Range]:
line, col = None, None
if isinstance(node, ast.Name):
line, col = node.lineno, node.col_offset
elif isinstance(node, (ast.Attribute, ast.Subscript)):
line, col = node.value.lineno, node.value.col_offset
if line is None or col is None:
return None
else:
return Range.singleton_span(line, col)
@staticmethod
def _get_binop_range_for(
node: ast.AST,
) -> Optional[Range]:
if isinstance(node, ast.BinOp):
left_end_lineno: Optional[int] = getattr(node.left, "end_lineno", None)
left_end_col_offset: Optional[int] = getattr(
node.left, "end_col_offset", None
)
if left_end_lineno is None or left_end_col_offset is None:
return None
else:
return Range(
Position(left_end_lineno, left_end_col_offset),
Position(node.right.lineno, node.right.col_offset),
)
else:
return None
def _get_boolop_range_for(self, node: ast.AST) -> Optional[Range]:
if not hasattr(node, "col_offset") or not isinstance(node, ast.expr):
return None
parent = self._tracers[-1].containing_ast_by_id.get(id(node))
if not isinstance(parent, ast.BoolOp):
return None
try:
value_index = parent.values.index(node)
if value_index + 1 >= len(parent.values):
return None
except ValueError:
return None
sibling = parent.values[value_index + 1]
end_lineno = getattr(sibling, "end_lineno", None)
end_col_offset = getattr(sibling, "end_col_offset", None)
if end_lineno is None or end_col_offset is None:
return None
return Range(
Position(node.lineno, node.col_offset), Position(end_lineno, end_col_offset)
)
def _get_call_range_for(self, node: ast.AST) -> Optional[Range]:
if not isinstance(node, ast.Call):
return None
end_lineno: Optional[int] = getattr(node.func, "end_lineno", None)
end_col_offset: Optional[int] = getattr(node.func, "end_col_offset", None)
if end_lineno is None or end_col_offset is None:
return None
return Range.singleton_span(end_lineno, end_col_offset)
@staticmethod
def _get_subscript_range_for(node: ast.AST) -> Optional[Range]:
# The opening bracket of a subscript sits immediately after the value,
# which is exactly the position a paired ``{`` -> ``[`` swap registers.
# Relies on end_col_offset, so subscript-augmentation detection requires
# Python 3.8+; on 3.7 this returns None (the swap still works, but a
# brace-block subscript is not distinguishable via get_augmentations).
if not isinstance(node, ast.Subscript):
return None
end_lineno: Optional[int] = getattr(node.value, "end_lineno", None)
end_col_offset: Optional[int] = getattr(node.value, "end_col_offset", None)
if end_lineno is None or end_col_offset is None:
return None
return Range.singleton_span(end_lineno, end_col_offset)
def _get_range_for(
    self, aug_type: AugmentationType, node: ast.AST
) -> Optional[Range]:
    """Dispatch to the range helper that matches ``aug_type``.

    Args:
        aug_type: Built-in augmentation type to resolve.
        node: AST node whose anchor range is wanted.

    Returns:
        Whatever the matching ``_get_*_range_for`` helper returns for ``node``.

    Raises:
        NotImplementedError: ``aug_type`` matches none of the known types.
    """
    # Checked in order with ``==``, mirroring an if/elif chain.
    helpers_by_type = (
        (AugmentationType.prefix, self._get_prefix_range_for),
        (AugmentationType.suffix, self._get_suffix_range_for),
        (AugmentationType.dot_suffix, self._get_dot_suffix_range_for),
        (AugmentationType.dot_prefix, self._get_dot_prefix_range_for),
        (AugmentationType.binop, self._get_binop_range_for),
        (AugmentationType.boolop, self._get_boolop_range_for),
        (AugmentationType.call, self._get_call_range_for),
        (AugmentationType.subscript, self._get_subscript_range_for),
    )
    for known_type, range_helper in helpers_by_type:
        if aug_type == known_type:
            return range_helper(node)
    raise NotImplementedError()
def _range_for_spec(self, spec: AugmentationSpec, node: ast.AST) -> Optional[Range]:
    """Resolve the anchor range of ``node`` for ``spec``.

    Built-in specs go through the per-``aug_type`` lookup in
    ``_get_range_for``. A custom spec instead answers through its own
    ``custom.range_for`` (it may anchor a synthesized node the built-in
    helpers don't know about).
    """
    if not spec.is_custom:
        return self._get_range_for(spec.aug_type, node)
    return spec.custom.range_for(node)  # type: ignore[union-attr]
def _handle_augmentations_for_node(
    self,
    augmented_positions_by_spec: Dict[AugmentationSpec, List[Position]],
    nc: ast.AST,
) -> None:
    """Mark ``nc`` as augmented for every spec with a position in its range.

    For each spec, the anchor range of ``nc`` is computed; if any registered
    position for that spec lies within the range (inclusive on both ends),
    ``id(nc)`` is added to ``augmented_node_ids_by_spec[spec]`` on every
    tracer that lists the spec in ``syntax_augmentation_specs()``.

    Args:
        augmented_positions_by_spec: Registered positions per spec, as
            produced by ``fix_positions``.
        nc: The node (copy) being checked.
    """
    for spec, mod_positions in augmented_positions_by_spec.items():
        aug_range = self._range_for_spec(spec, nc)
        if aug_range is None:
            continue
        start_pos, end_pos = aug_range
        # bisect narrows the candidates; its results always lie within
        # [0, len(mod_positions)], so they can be used as slice bounds directly.
        first = bisect.bisect_left(mod_positions, start_pos)
        stop = bisect.bisect_right(mod_positions, end_pos)
        in_range = any(
            start_pos <= candidate <= end_pos
            for candidate in mod_positions[first:stop]
        )
        if not in_range:
            continue
        node_id = id(nc)
        for tracer in self._tracers:
            if spec in tracer.syntax_augmentation_specs():
                tracer.augmented_node_ids_by_spec[spec].add(node_id)
def _handle_all_augmentations(
    self, orig_to_copy_mapping: Dict[int, ast.AST]
) -> None:
    """Run augmentation detection over every copied node.

    The registered positions are first passed through ``fix_positions``
    (using the order in which specs were applied), then each copy in
    ``orig_to_copy_mapping`` is checked against them.
    """
    registered_positions = self._augmented_positions_by_spec
    applied_order = self._get_order_of_specs_applied()
    fixed_positions_by_spec = fix_positions(
        registered_positions, spec_order=applied_order
    )
    for node_copy in orig_to_copy_mapping.values():
        self._handle_augmentations_for_node(fixed_positions_by_spec, node_copy)
[docs]
def visit(self, node: ast.AST, instrument: bool = True):
# Entry point of the rewriter.
#
# Steps, in order:
#   1. copy ``node`` through ``StatementMapper`` (original id -> copy);
#   2. refresh the AST bookkeeping held by the last tracer for this path;
#   3. detect syntax augmentations on the copies;
#   4. unless ``instrument`` is false, combine the tracers' handler
#      predicates and guards, then run ``ExprRewriter`` followed by
#      ``StatementInserter`` over ``node``.
#
# Args:
#   node: an ``ast.Expression``, ``ast.Module``, ``ast.FunctionDef`` or
#       ``ast.AsyncFunctionDef`` (asserted below).
#   instrument: when false, return the augmentation-annotated copy after
#       step 3 without any instrumentation.
#
# Returns: the instrumented ``node``, or the copy when ``instrument`` is false.
#
# NOTE: the validation below uses ``assert`` and is therefore skipped when
# Python runs with ``-O``.
assert isinstance(
node, (ast.Expression, ast.Module, ast.FunctionDef, ast.AsyncFunctionDef)
)
assert self._path is not None
mapper = StatementMapper(self._tracers)
orig_to_copy_mapping = mapper(node)
# Bookkeeping is only ever read from / written to the last tracer in the list.
last_tracer = self._tracers[-1]
old_bookkeeper = last_tracer.ast_bookkeeper_by_fname.get(self._path)
# Fall back to the identity of the incoming node when no explicit id was given.
module_id = id(node) if self._module_id is None else self._module_id
# garbage collect any stale references to aug specs once they have been propagated
cleanup_bookkeeper = AstBookkeeper.create(self._path, module_id)
BookkeepingVisitor(cleanup_bookkeeper).visit(node)
last_tracer.remove_bookkeeping(cleanup_bookkeeper, module_id)
# Install a fresh bookkeeper for this path; it is populated from the *copy* below.
new_bookkeeper = last_tracer.ast_bookkeeper_by_fname[self._path] = (
AstBookkeeper.create(self._path, module_id)
)
# Drop what the previous bookkeeper for this path registered, unless
# ``gc_bookkeeping`` has been turned off.
if old_bookkeeper is not None and self.gc_bookkeeping:
last_tracer.remove_bookkeeping(old_bookkeeper, module_id)
BookkeepingVisitor(new_bookkeeper).visit(orig_to_copy_mapping[id(node)])
last_tracer.add_bookkeeping(new_bookkeeper, module_id)
# Publish the mapping so the flyweight predicates built below can resolve copies.
self.orig_to_copy_mapping = orig_to_copy_mapping
self._handle_all_augmentations(orig_to_copy_mapping)
if not instrument:
# Return the augmentation-annotated *copy* (whose node ids are the keys
# in ``augmented_node_ids_by_spec``) without weaving in any tracing
# instrumentation. This is the clean tree ``untransform`` consumes.
return orig_to_copy_mapping[id(node)]
# Per event: the predicates of every registered handler ...
raw_handler_predicates_by_event: DefaultDict[TraceEvent, List[Predicate]] = (
defaultdict(list)
)
# ... and the subset belonging to handlers that must not be guarded.
raw_guard_exempt_handler_predicates_by_event: DefaultDict[
TraceEvent, List[Predicate]
] = defaultdict(list)
# A tracer with guards globally disabled wants *none* of its handlers
# guarded, but that only matters when some *other* tracer enables guards
# (the rewriter inserts guards iff global_guards_enabled, which is the
# ``any`` over tracers). When no tracer wants guards there is nothing to be
# exempt from, so skip building the exempt predicates entirely -- both to
# avoid pointless work and, importantly, to leave the standalone code path
# byte-for-byte unchanged (building extra predicates perturbs object
# identities, which some id()-order-sensitive bookkeeping is fragile to).
any_guards_enabled = any(
tracer.global_guards_enabled for tracer in self._tracers
)
for tracer in self._tracers:
# Tracers that opt out of this file contribute no predicates.
if not self.should_instrument_with_tracer(tracer):
continue
for evt in tracer.events_with_registered_handlers:
# this is to deal with the tests in test_trace_events.py,
# which patch events_with_registered_handlers but do not add them to _event_handlers
handler_data = tracer._event_handlers.get(
evt, [HandlerSpec.empty()] # type: ignore
)
for handler_spec in handler_data:
raw_handler_predicates_by_event[evt].append(handler_spec.predicate)
# Honor a guard-disabled tracer's intent by treating all of its
# handlers as guard-exempt -- but only when guards are actually
# in play (see above).
if handler_spec.exempt_from_guards or (
any_guards_enabled and not tracer.global_guards_enabled
):
raw_guard_exempt_handler_predicates_by_event[evt].append(
handler_spec.predicate
)
# Events without any registered handler default to a predicate that is always False.
handler_predicate_by_event: DefaultDict[
TraceEvent, Callable[..., bool]
] = defaultdict(
lambda: (lambda *_: False) # type: ignore
)
# Same default for the guard-exempt predicates.
guard_exempt_handler_prediate_by_event: DefaultDict[
TraceEvent, Callable[..., bool]
] = defaultdict(
lambda: (lambda *_: False) # type: ignore
)
# Collapse each event's predicates into a single "any of them" predicate that
# is evaluated on the node copy rather than on the original node.
for evt, raw_predicates in raw_handler_predicates_by_event.items():
handler_predicate_by_event[evt] = self._make_node_copy_flyweight(
CompositePredicate.any(raw_predicates)
)
for evt, raw_predicates in raw_guard_exempt_handler_predicates_by_event.items():
guard_exempt_handler_prediate_by_event[evt] = (
self._make_node_copy_flyweight(CompositePredicate.any(raw_predicates))
)
# Per event: (handler spec, guard function) pairs for handlers that declare a
# guard. Unlike the predicate loop above, this is collected from every tracer
# without consulting ``should_instrument_with_tracer``.
handler_guards_by_event: DefaultDict[TraceEvent, List[GUARD_DATA_T]] = (
defaultdict(list)
)
for tracer in self._tracers:
for evt, handler_specs in tracer._event_handlers.items():
handler_guards_by_event[evt].extend(
(spec, self._make_node_copy_flyweight(spec.guard))
for spec in handler_specs
if spec.guard is not None
)
# Modules get a static-initialization pass from every tracer, run on the copy.
if isinstance(node, ast.Module):
for tracer in self._tracers:
tracer._static_init_module_impl(
orig_to_copy_mapping.get(id(node), node) # type: ignore
)
# very important that the eavesdropper does not create new ast nodes for ast.stmt (but just
# modifies existing ones), since StatementInserter relies on being able to map these
expr_rewriter = ExprRewriter(
self._tracers,
mapper,
orig_to_copy_mapping,
handler_predicate_by_event,
guard_exempt_handler_prediate_by_event,
handler_guards_by_event,
)
# An Expression is rewritten as a whole; otherwise each top-level statement of
# the body is rewritten and stored back into the same slot.
if isinstance(node, ast.Expression):
node = expr_rewriter.visit(node)
else:
for i in range(len(node.body)):
node.body[i] = expr_rewriter.visit(node.body[i])
# Second pass: statement-level rewriting, reusing the expression rewriter.
node = StatementInserter(
self._tracers,
mapper,
orig_to_copy_mapping,
handler_predicate_by_event,
guard_exempt_handler_prediate_by_event,
handler_guards_by_event,
expr_rewriter,
).visit(node)
# If no tracer asks to keep AST bookkeeping, undo what was registered above.
if not any(tracer.requires_ast_bookkeeping for tracer in self._tracers):
last_tracer.remove_bookkeeping(new_bookkeeper, module_id)
return node