diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 59b2349c59..99dbe37161 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -220,6 +220,7 @@ class NodeNG(object): :type: bool """ + is_lambda = False # Attributes below are set by the builder module or by raw factories lineno = None """The line that this node appears on in the source code. @@ -627,12 +628,44 @@ def nodes_of_class(self, klass, skip_klass=None): """ if isinstance(self, klass): yield self + + if skip_klass is None: + for child_node in self.get_children(): + for matching in child_node.nodes_of_class(klass, skip_klass): + yield matching + + return + for child_node in self.get_children(): - if skip_klass is not None and isinstance(child_node, skip_klass): + if isinstance(child_node, skip_klass): continue for matching in child_node.nodes_of_class(klass, skip_klass): yield matching + def _get_assign_nodes(self): + for child_node in self.get_children(): + for matching in child_node._get_assign_nodes(): + yield matching + + def _get_name_nodes(self): + for child_node in self.get_children(): + for matching in child_node._get_name_nodes(): + yield matching + + def _get_return_nodes_skip_functions(self): + for child_node in self.get_children(): + if child_node.is_function: + continue + for matching in child_node._get_return_nodes_skip_functions(): + yield matching + + def _get_yield_nodes_skip_lambdas(self): + for child_node in self.get_children(): + if child_node.is_lambda: + continue + for matching in child_node._get_yield_nodes_skip_lambdas(): + yield matching + def _infer_name(self, frame, name): # overridden for ImportFrom, Import, Global, TryExcept and Arguments return None @@ -965,6 +998,9 @@ def pytype(self): :rtype: str """ + def get_children(self): + yield from self.elts + class LookupMixIn(object): """Mixin to look up a name in the right scope.""" @@ -1170,6 +1206,9 @@ def __init__(self, name=None, lineno=None, col_offset=None, parent=None): super(AssignName, self).__init__(lineno, col_offset, parent) + def get_children(self): + yield from () + class DelName(LookupMixIn, mixins.ParentAssignTypeMixin, NodeNG): """Variation of :class:`ast.Delete` represention deletion of a name. @@ -1207,6 +1246,9 @@ def __init__(self, name=None, lineno=None, col_offset=None, parent=None): super(DelName, self).__init__(lineno, col_offset, parent) + def get_children(self): + yield from () + class Name(LookupMixIn, NodeNG): """Class representing an :class:`ast.Name` node. @@ -1247,6 +1289,16 @@ def __init__(self, name=None, lineno=None, col_offset=None, parent=None): super(Name, self).__init__(lineno, col_offset, parent) + def get_children(self): + yield from () + + def _get_name_nodes(self): + yield self + + for child_node in self.get_children(): + for matching in child_node._get_name_nodes(): + yield matching + class Arguments(mixins.AssignTypeMixin, NodeNG): """Class representing an :class:`ast.arguments` node. @@ -1489,16 +1541,28 @@ def find_argname(self, argname, rec=False): return None, None def get_children(self): - """Get the child nodes below this node. + yield from self.args or () - This skips over `None` elements in :attr:`kw_defaults`. + yield from self.defaults + yield from self.kwonlyargs - :returns: The children. - :rtype: iterable(NodeNG) - """ - for child in super(Arguments, self).get_children(): - if child is not None: - yield child + if self.varargannotation is not None: + yield self.varargannotation + + if self.kwargannotation is not None: + yield self.kwargannotation + + for elt in self.kw_defaults: + if elt is not None: + yield elt + + for elt in self.annotations: + if elt is not None: + yield elt + + for elt in self.kwonlyargs_annotations: + if elt is not None: + yield elt def _find_arg(argname, args, rec=False): @@ -1587,6 +1651,9 @@ def postinit(self, expr=None): """ self.expr = expr + def get_children(self): + yield self.expr + class Assert(Statement): """Class representing an :class:`ast.Assert` node. @@ -1621,6 +1688,12 @@ def postinit(self, test=None, fail=None): self.fail = fail self.test = test + def get_children(self): + yield self.test + + if self.fail is not None: + yield self.fail + class Assign(mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.Assign` node. @@ -1656,6 +1729,18 @@ def postinit(self, targets=None, value=None): self.targets = targets self.value = value + def get_children(self): + yield from self.targets + + yield self.value + + def _get_assign_nodes(self): + yield self + + for child_node in self.get_children(): + for matching in child_node._get_assign_nodes(): + yield matching + class AnnAssign(mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.AnnAssign` node. @@ -1711,6 +1796,13 @@ def postinit(self, target, annotation, simple, value=None): self.value = value self.simple = simple + def get_children(self): + yield self.target + yield self.annotation + + if self.value is not None: + yield self.value + class AugAssign(mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.AugAssign` node. @@ -1792,6 +1884,10 @@ def type_errors(self, context=None): except exceptions.InferenceError: return [] + def get_children(self): + yield self.target + yield self.value + class Repr(NodeNG): """Class representing an :class:`ast.Repr` node. @@ -1896,6 +1992,10 @@ def type_errors(self, context=None): except exceptions.InferenceError: return [] + def get_children(self): + yield self.left + yield self.right + class BoolOp(NodeNG): """Class representing an :class:`ast.BoolOp` node. @@ -1945,6 +2045,9 @@ def postinit(self, values=None): """ self.values = values + def get_children(self): + yield from self.values + class Break(Statement): """Class representing an :class:`ast.Break` node. @@ -1954,6 +2057,9 @@ class Break(Statement): """ + def get_children(self): + yield from () + class Call(NodeNG): """Class representing an :class:`ast.Call` node. @@ -2015,6 +2121,13 @@ def kwargs(self): keywords = self.keywords or [] return [keyword for keyword in keywords if keyword.arg is None] + def get_children(self): + yield self.func + + yield from self.args + + yield from self.keywords or () + class Compare(NodeNG): """Class representing an :class:`ast.Compare` node. @@ -2183,6 +2296,12 @@ def _get_filtered_stmts(self, lookup_node, node, stmts, mystmt): return stmts, False + def get_children(self): + yield self.target + yield self.iter + + yield from self.ifs + class Const(NodeNG, bases.Instance): """Class representing any constant including num, str, bool, None, bytes. @@ -2295,6 +2414,9 @@ def bool_value(self): """ return bool(self.value) + def get_children(self): + yield from () + class Continue(Statement): """Class representing an :class:`ast.Continue` node. @@ -2304,6 +2426,9 @@ class Continue(Statement): """ + def get_children(self): + yield from () + class Decorators(NodeNG): """A node representing a list of decorators. @@ -2345,6 +2470,9 @@ def scope(self): # skip the function node to go directly to the upper level scope return self.parent.parent.scope() + def get_children(self): + yield from self.nodes + class DelAttr(mixins.ParentAssignTypeMixin, NodeNG): """Variation of :class:`ast.Delete` representing deletion of an attribute. @@ -2394,6 +2522,9 @@ def postinit(self, expr=None): """ self.expr = expr + def get_children(self): + yield self.expr + class Delete(mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.Delete` node. @@ -2419,6 +2550,9 @@ def postinit(self, targets=None): """ self.targets = targets + def get_children(self): + yield from self.targets + class Dict(NodeNG, bases.Instance): """Class representing an :class:`ast.Dict` node. @@ -2579,6 +2713,9 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + yield self.value + class Ellipsis(NodeNG): # pylint: disable=redefined-builtin """Class representing an :class:`ast.Ellipsis` node. @@ -2599,12 +2736,18 @@ def bool_value(self): """ return True + def get_children(self): + yield from () + class EmptyNode(NodeNG): """Holds an arbitrary object in the :attr:`LocalsDictNodeNG.locals`.""" object = None + def get_children(self): + yield from () + class ExceptHandler(mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.ExceptHandler`. node. @@ -2639,6 +2782,15 @@ class ExceptHandler(mixins.AssignTypeMixin, Statement): :type: list(NodeNG) or None """ + def get_children(self): + if self.type is not None: + yield self.type + + if self.name is not None: + yield self.name + + yield from self.body + # pylint: disable=redefined-builtin; had to use the same name as builtin ast module. def postinit(self, type=None, name=None, body=None): """Do some setup after initialisation. @@ -2679,7 +2831,7 @@ def catch(self, exceptions): # pylint: disable=redefined-outer-name """ if self.type is None or exceptions is None: return True - for node in self.type.nodes_of_class(Name): + for node in self.type._get_name_nodes(): if node.name in exceptions: return True return False @@ -2820,6 +2972,13 @@ def blockstart_tolineno(self): """ return self.iter.tolineno + def get_children(self): + yield self.target + yield self.iter + + yield from self.body + yield from self.orelse + class AsyncFor(For): """Class representing an :class:`ast.AsyncFor` node. @@ -2871,6 +3030,9 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + yield self.value + class ImportFrom(mixins.ImportFromMixin, Statement): """Class representing an :class:`ast.ImportFrom` node. @@ -2931,6 +3093,9 @@ def __init__(self, fromname, names, level=0, lineno=None, super(ImportFrom, self).__init__(lineno, col_offset, parent) + def get_children(self): + yield from () + class Attribute(NodeNG): """Class representing an :class:`ast.Attribute` node.""" @@ -2973,6 +3138,9 @@ def postinit(self, expr=None): """ self.expr = expr + def get_children(self): + yield self.expr + class Global(Statement): """Class representing an :class:`ast.Global` node. @@ -3009,6 +3177,9 @@ def __init__(self, names, lineno=None, col_offset=None, parent=None): def _infer_name(self, frame, name): return name + def get_children(self): + yield from () + class If(mixins.BlockRangeMixIn, Statement): """Class representing an :class:`ast.If` node. @@ -3075,6 +3246,12 @@ def block_range(self, lineno): return self._elsed_block_range(lineno, self.orelse, self.body[0].fromlineno - 1) + def get_children(self): + yield self.test + + yield from self.body + yield from self.orelse + class IfExp(NodeNG): """Class representing an :class:`ast.IfExp` node. @@ -3116,6 +3293,12 @@ def postinit(self, test=None, body=None, orelse=None): self.body = body self.orelse = orelse + def get_children(self): + yield self.test + yield self.body + yield self.orelse + + class Import(mixins.ImportFromMixin, Statement): """Class representing an :class:`ast.Import` node. @@ -3151,6 +3334,9 @@ def __init__(self, names=None, lineno=None, col_offset=None, parent=None): super(Import, self).__init__(lineno, col_offset, parent) + def get_children(self): + yield from () + class Index(NodeNG): """Class representing an :class:`ast.Index` node. @@ -3178,6 +3364,9 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + yield self.value + class Keyword(NodeNG): """Class representing an :class:`ast.keyword` node. @@ -3227,6 +3416,9 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + yield self.value + class List(_BaseContainer): """Class representing an :class:`ast.List` node. @@ -3318,6 +3510,9 @@ def __init__(self, names, lineno=None, col_offset=None, parent=None): def _infer_name(self, frame, name): return name + def get_children(self): + yield from () + class Pass(Statement): """Class representing an :class:`ast.Pass` node. @@ -3327,6 +3522,9 @@ class Pass(Statement): """ + def get_children(self): + yield from () + class Print(Statement): """Class representing an :class:`ast.Print` node. @@ -3423,11 +3621,18 @@ def raises_not_implemented(self): """ if not self.exc: return False - for name in self.exc.nodes_of_class(Name): + for name in self.exc._get_name_nodes(): if name.name == 'NotImplementedError': return True return False + def get_children(self): + if self.exc is not None: + yield self.exc + + if self.cause is not None: + yield self.cause + class Return(Statement): """Class representing an :class:`ast.Return` node. @@ -3451,6 +3656,19 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + if self.value is not None: + yield self.value + + def _get_return_nodes_skip_functions(self): + yield self + + for child_node in self.get_children(): + if child_node.is_function: + continue + for matching in child_node._get_return_nodes_skip_functions(): + yield matching + class Set(_BaseContainer): """Class representing an :class:`ast.Set` node. @@ -3554,6 +3772,16 @@ def igetattr(self, attrname, context=None): def getattr(self, attrname, context=None): return self._proxied.getattr(attrname, context) + def get_children(self): + if self.lower is not None: + yield self.lower + + if self.step is not None: + yield self.step + + if self.upper is not None: + yield self.upper + class Starred(mixins.ParentAssignTypeMixin, NodeNG): """Class representing an :class:`ast.Starred` node. @@ -3602,6 +3830,9 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + yield self.value + class Subscript(NodeNG): """Class representing an :class:`ast.Subscript` node. @@ -3660,6 +3891,10 @@ def postinit(self, value=None, slice=None): self.value = value self.slice = slice + def get_children(self): + yield self.value + yield self.slice + class TryExcept(mixins.BlockRangeMixIn, Statement): """Class representing an :class:`ast.TryExcept` node. @@ -3729,6 +3964,12 @@ def block_range(self, lineno): last = exhandler.body[0].fromlineno - 1 return self._elsed_block_range(lineno, self.orelse, last) + def get_children(self): + yield from self.body + + yield from self.handlers or () + yield from self.orelse or () + class TryFinally(mixins.BlockRangeMixIn, Statement): """Class representing an :class:`ast.TryFinally` node. @@ -3785,6 +4026,10 @@ def block_range(self, lineno): return child.block_range(lineno) return self._elsed_block_range(lineno, self.finalbody) + def get_children(self): + yield from self.body + yield from self.finalbody + class Tuple(_BaseContainer): """Class representing an :class:`ast.Tuple` node. @@ -3903,6 +4148,9 @@ def type_errors(self, context=None): except exceptions.InferenceError: return [] + def get_children(self): + yield self.operand + class While(mixins.BlockRangeMixIn, Statement): """Class representing an :class:`ast.While` node. @@ -3967,6 +4215,12 @@ def block_range(self, lineno): """ return self. _elsed_block_range(lineno, self.orelse) + def get_children(self): + yield self.test + + yield from self.body + yield from self.orelse + class With(mixins.BlockRangeMixIn, mixins.AssignTypeMixin, Statement): """Class representing an :class:`ast.With` node. @@ -4051,6 +4305,19 @@ def postinit(self, value=None): """ self.value = value + def get_children(self): + if self.value is not None: + yield self.value + + def _get_yield_nodes_skip_lambdas(self): + yield self + + for child_node in self.get_children(): + if child_node.is_function_or_lambda: + continue + for matching in child_node._get_yield_nodes_skip_lambdas(): + yield matching + class YieldFrom(Yield): """Class representing an :class:`ast.YieldFrom` node.""" @@ -4059,6 +4326,9 @@ class YieldFrom(Yield): class DictUnpack(NodeNG): """Represents the unpacking of dicts into dicts using :pep:`448`.""" + def get_children(self): + yield from () + class FormattedValue(NodeNG): """Class representing an :class:`ast.FormattedValue` node. @@ -4110,6 +4380,12 @@ def postinit(self, value, conversion=None, format_spec=None): self.conversion = conversion self.format_spec = format_spec + def get_children(self): + yield self.value + + if self.format_spec is not None: + yield self.format_spec + class JoinedStr(NodeNG): """Represents a list of string expressions to be joined. @@ -4134,6 +4410,9 @@ def postinit(self, values=None): """ self.values = values + def get_children(self): + yield from self.values + class Unknown(mixins.AssignTypeMixin, NodeNG): """This node represents a node in a constructed AST where diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index 40970671b3..8486b8edac 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -696,6 +696,10 @@ def bool_value(self): """ return True + def get_children(self): + for elt in self.body: + yield elt + class ComprehensionScope(LocalsDictNodeNG): """Scoping for different types of comprehensions.""" @@ -777,6 +781,11 @@ def bool_value(self): """ return True + def get_children(self): + yield self.elt + + yield from self.generators + class DictComp(ComprehensionScope): """Class representing an :class:`ast.DictComp` node. @@ -851,6 +860,12 @@ def bool_value(self): """ return util.Uninferable + def get_children(self): + yield self.key + yield self.value + + yield from self.generators + class SetComp(ComprehensionScope): """Class representing an :class:`ast.SetComp` node. @@ -916,6 +931,11 @@ def bool_value(self): """ return util.Uninferable + def get_children(self): + yield self.elt + + yield from self.generators + class _ListComp(node_classes.NodeNG): """Class representing an :class:`ast.ListComp` node. @@ -957,6 +977,11 @@ def bool_value(self): """ return util.Uninferable + def get_children(self): + yield self.elt + + yield from self.generators + class ListComp(_ListComp, ComprehensionScope): """Class representing an :class:`ast.ListComp` node. @@ -1012,6 +1037,7 @@ class Lambda(mixins.FilterStmtsMixin, LocalsDictNodeNG): _astroid_fields = ('args', 'body',) _other_other_fields = ('locals',) name = '' + is_lambda = True # function's type, 'function' | 'method' | 'staticmethod' | 'classmethod' @property @@ -1174,6 +1200,10 @@ def bool_value(self): """ return True + def get_children(self): + yield self.args + yield self.body + class FunctionDef(node_classes.Statement, Lambda): """Class representing an :class:`ast.FunctionDef`. @@ -1284,7 +1314,7 @@ def extra_decorators(self): return [] decorators = [] - for assign in frame.nodes_of_class(node_classes.Assign): + for assign in frame._get_assign_nodes(): if (isinstance(assign.value, node_classes.Call) and isinstance(assign.value.func, node_classes.Name)): for assign_node in assign.targets: @@ -1501,9 +1531,7 @@ def is_generator(self): :returns: True is this is a generator function, False otherwise. :rtype: bool """ - yield_nodes = (node_classes.Yield, node_classes.YieldFrom) - return next(self.nodes_of_class(yield_nodes, - skip_klass=(FunctionDef, Lambda)), False) + return next(self._get_yield_nodes_skip_lambdas(), False) def infer_call_result(self, caller=None, context=None): """Infer what the function returns when called. @@ -1534,7 +1562,7 @@ def infer_call_result(self, caller=None, context=None): c._metaclass = metaclass yield c return - returns = self.nodes_of_class(node_classes.Return, skip_klass=FunctionDef) + returns = self._get_return_nodes_skip_functions() for returnnode in returns: if returnnode.value is None: yield node_classes.Const(None) @@ -1554,6 +1582,18 @@ def bool_value(self): """ return True + def get_children(self): + if self.decorators is not None: + yield self.decorators + + yield self.args + + if self.returns is not None: + yield self.returns + + for elt in self.body: + yield elt + class AsyncFunctionDef(FunctionDef): """Class representing an :class:`ast.FunctionDef` node. @@ -2649,6 +2689,16 @@ def bool_value(self): """ return True + def get_children(self): + for elt in self.body: + yield elt + + for elt in self.bases: + yield elt + + if self.decorators is not None: + yield self.decorators + # Backwards-compatibility aliases Class = util.proxy_alias('Class', ClassDef)