source.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. from __future__ import generators
  2. from bisect import bisect_right
  3. import sys
  4. import inspect, tokenize
  5. import py
  6. from types import ModuleType
  7. cpy_compile = compile
  8. try:
  9. import _ast
  10. from _ast import PyCF_ONLY_AST as _AST_FLAG
  11. except ImportError:
  12. _AST_FLAG = 0
  13. _ast = None
  14. class Source(object):
  15. """ a immutable object holding a source code fragment,
  16. possibly deindenting it.
  17. """
  18. _compilecounter = 0
  19. def __init__(self, *parts, **kwargs):
  20. self.lines = lines = []
  21. de = kwargs.get('deindent', True)
  22. rstrip = kwargs.get('rstrip', True)
  23. for part in parts:
  24. if not part:
  25. partlines = []
  26. if isinstance(part, Source):
  27. partlines = part.lines
  28. elif isinstance(part, (tuple, list)):
  29. partlines = [x.rstrip("\n") for x in part]
  30. elif isinstance(part, py.builtin._basestring):
  31. partlines = part.split('\n')
  32. if rstrip:
  33. while partlines:
  34. if partlines[-1].strip():
  35. break
  36. partlines.pop()
  37. else:
  38. partlines = getsource(part, deindent=de).lines
  39. if de:
  40. partlines = deindent(partlines)
  41. lines.extend(partlines)
  42. def __eq__(self, other):
  43. try:
  44. return self.lines == other.lines
  45. except AttributeError:
  46. if isinstance(other, str):
  47. return str(self) == other
  48. return False
  49. def __getitem__(self, key):
  50. if isinstance(key, int):
  51. return self.lines[key]
  52. else:
  53. if key.step not in (None, 1):
  54. raise IndexError("cannot slice a Source with a step")
  55. return self.__getslice__(key.start, key.stop)
  56. def __len__(self):
  57. return len(self.lines)
  58. def __getslice__(self, start, end):
  59. newsource = Source()
  60. newsource.lines = self.lines[start:end]
  61. return newsource
  62. def strip(self):
  63. """ return new source object with trailing
  64. and leading blank lines removed.
  65. """
  66. start, end = 0, len(self)
  67. while start < end and not self.lines[start].strip():
  68. start += 1
  69. while end > start and not self.lines[end-1].strip():
  70. end -= 1
  71. source = Source()
  72. source.lines[:] = self.lines[start:end]
  73. return source
  74. def putaround(self, before='', after='', indent=' ' * 4):
  75. """ return a copy of the source object with
  76. 'before' and 'after' wrapped around it.
  77. """
  78. before = Source(before)
  79. after = Source(after)
  80. newsource = Source()
  81. lines = [ (indent + line) for line in self.lines]
  82. newsource.lines = before.lines + lines + after.lines
  83. return newsource
  84. def indent(self, indent=' ' * 4):
  85. """ return a copy of the source object with
  86. all lines indented by the given indent-string.
  87. """
  88. newsource = Source()
  89. newsource.lines = [(indent+line) for line in self.lines]
  90. return newsource
  91. def getstatement(self, lineno, assertion=False):
  92. """ return Source statement which contains the
  93. given linenumber (counted from 0).
  94. """
  95. start, end = self.getstatementrange(lineno, assertion)
  96. return self[start:end]
  97. def getstatementrange(self, lineno, assertion=False):
  98. """ return (start, end) tuple which spans the minimal
  99. statement region which containing the given lineno.
  100. """
  101. if not (0 <= lineno < len(self)):
  102. raise IndexError("lineno out of range")
  103. ast, start, end = getstatementrange_ast(lineno, self)
  104. return start, end
  105. def deindent(self, offset=None):
  106. """ return a new source object deindented by offset.
  107. If offset is None then guess an indentation offset from
  108. the first non-blank line. Subsequent lines which have a
  109. lower indentation offset will be copied verbatim as
  110. they are assumed to be part of multilines.
  111. """
  112. # XXX maybe use the tokenizer to properly handle multiline
  113. # strings etc.pp?
  114. newsource = Source()
  115. newsource.lines[:] = deindent(self.lines, offset)
  116. return newsource
  117. def isparseable(self, deindent=True):
  118. """ return True if source is parseable, heuristically
  119. deindenting it by default.
  120. """
  121. try:
  122. import parser
  123. except ImportError:
  124. syntax_checker = lambda x: compile(x, 'asd', 'exec')
  125. else:
  126. syntax_checker = parser.suite
  127. if deindent:
  128. source = str(self.deindent())
  129. else:
  130. source = str(self)
  131. try:
  132. #compile(source+'\n', "x", "exec")
  133. syntax_checker(source+'\n')
  134. except KeyboardInterrupt:
  135. raise
  136. except Exception:
  137. return False
  138. else:
  139. return True
  140. def __str__(self):
  141. return "\n".join(self.lines)
  142. def compile(self, filename=None, mode='exec',
  143. flag=generators.compiler_flag,
  144. dont_inherit=0, _genframe=None):
  145. """ return compiled code object. if filename is None
  146. invent an artificial filename which displays
  147. the source/line position of the caller frame.
  148. """
  149. if not filename or py.path.local(filename).check(file=0):
  150. if _genframe is None:
  151. _genframe = sys._getframe(1) # the caller
  152. fn,lineno = _genframe.f_code.co_filename, _genframe.f_lineno
  153. base = "<%d-codegen " % self._compilecounter
  154. self.__class__._compilecounter += 1
  155. if not filename:
  156. filename = base + '%s:%d>' % (fn, lineno)
  157. else:
  158. filename = base + '%r %s:%d>' % (filename, fn, lineno)
  159. source = "\n".join(self.lines) + '\n'
  160. try:
  161. co = cpy_compile(source, filename, mode, flag)
  162. except SyntaxError:
  163. ex = sys.exc_info()[1]
  164. # re-represent syntax errors from parsing python strings
  165. msglines = self.lines[:ex.lineno]
  166. if ex.offset:
  167. msglines.append(" "*ex.offset + '^')
  168. msglines.append("(code was compiled probably from here: %s)" % filename)
  169. newex = SyntaxError('\n'.join(msglines))
  170. newex.offset = ex.offset
  171. newex.lineno = ex.lineno
  172. newex.text = ex.text
  173. raise newex
  174. else:
  175. if flag & _AST_FLAG:
  176. return co
  177. lines = [(x + "\n") for x in self.lines]
  178. import linecache
  179. linecache.cache[filename] = (1, None, lines, filename)
  180. return co
  181. #
  182. # public API shortcut functions
  183. #
  184. def compile_(source, filename=None, mode='exec', flags=
  185. generators.compiler_flag, dont_inherit=0):
  186. """ compile the given source to a raw code object,
  187. and maintain an internal cache which allows later
  188. retrieval of the source code for the code object
  189. and any recursively created code objects.
  190. """
  191. if _ast is not None and isinstance(source, _ast.AST):
  192. # XXX should Source support having AST?
  193. return cpy_compile(source, filename, mode, flags, dont_inherit)
  194. _genframe = sys._getframe(1) # the caller
  195. s = Source(source)
  196. co = s.compile(filename, mode, flags, _genframe=_genframe)
  197. return co
  198. def getfslineno(obj):
  199. """ Return source location (path, lineno) for the given object.
  200. If the source cannot be determined return ("", -1)
  201. """
  202. try:
  203. code = py.code.Code(obj)
  204. except TypeError:
  205. try:
  206. fn = (inspect.getsourcefile(obj) or
  207. inspect.getfile(obj))
  208. except TypeError:
  209. return "", -1
  210. fspath = fn and py.path.local(fn) or None
  211. lineno = -1
  212. if fspath:
  213. try:
  214. _, lineno = findsource(obj)
  215. except IOError:
  216. pass
  217. else:
  218. fspath = code.path
  219. lineno = code.firstlineno
  220. assert isinstance(lineno, int)
  221. return fspath, lineno
  222. #
  223. # helper functions
  224. #
  225. def findsource(obj):
  226. try:
  227. sourcelines, lineno = inspect.findsource(obj)
  228. except py.builtin._sysex:
  229. raise
  230. except:
  231. return None, -1
  232. source = Source()
  233. source.lines = [line.rstrip() for line in sourcelines]
  234. return source, lineno
  235. def getsource(obj, **kwargs):
  236. obj = py.code.getrawcode(obj)
  237. try:
  238. strsrc = inspect.getsource(obj)
  239. except IndentationError:
  240. strsrc = "\"Buggy python version consider upgrading, cannot get source\""
  241. assert isinstance(strsrc, str)
  242. return Source(strsrc, **kwargs)
  243. def deindent(lines, offset=None):
  244. if offset is None:
  245. for line in lines:
  246. line = line.expandtabs()
  247. s = line.lstrip()
  248. if s:
  249. offset = len(line)-len(s)
  250. break
  251. else:
  252. offset = 0
  253. if offset == 0:
  254. return list(lines)
  255. newlines = []
  256. def readline_generator(lines):
  257. for line in lines:
  258. yield line + '\n'
  259. while True:
  260. yield ''
  261. it = readline_generator(lines)
  262. try:
  263. for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)):
  264. if sline > len(lines):
  265. break # End of input reached
  266. if sline > len(newlines):
  267. line = lines[sline - 1].expandtabs()
  268. if line.lstrip() and line[:offset].isspace():
  269. line = line[offset:] # Deindent
  270. newlines.append(line)
  271. for i in range(sline, eline):
  272. # Don't deindent continuing lines of
  273. # multiline tokens (i.e. multiline strings)
  274. newlines.append(lines[i])
  275. except (IndentationError, tokenize.TokenError):
  276. pass
  277. # Add any lines we didn't see. E.g. if an exception was raised.
  278. newlines.extend(lines[len(newlines):])
  279. return newlines
  280. def get_statement_startend2(lineno, node):
  281. import ast
  282. # flatten all statements and except handlers into one lineno-list
  283. # AST's line numbers start indexing at 1
  284. l = []
  285. for x in ast.walk(node):
  286. if isinstance(x, _ast.stmt) or isinstance(x, _ast.ExceptHandler):
  287. l.append(x.lineno - 1)
  288. for name in "finalbody", "orelse":
  289. val = getattr(x, name, None)
  290. if val:
  291. # treat the finally/orelse part as its own statement
  292. l.append(val[0].lineno - 1 - 1)
  293. l.sort()
  294. insert_index = bisect_right(l, lineno)
  295. start = l[insert_index - 1]
  296. if insert_index >= len(l):
  297. end = None
  298. else:
  299. end = l[insert_index]
  300. return start, end
  301. def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
  302. if astnode is None:
  303. content = str(source)
  304. try:
  305. astnode = compile(content, "source", "exec", 1024) # 1024 for AST
  306. except ValueError:
  307. start, end = getstatementrange_old(lineno, source, assertion)
  308. return None, start, end
  309. start, end = get_statement_startend2(lineno, astnode)
  310. # we need to correct the end:
  311. # - ast-parsing strips comments
  312. # - there might be empty lines
  313. # - we might have lesser indented code blocks at the end
  314. if end is None:
  315. end = len(source.lines)
  316. if end > start + 1:
  317. # make sure we don't span differently indented code blocks
  318. # by using the BlockFinder helper used which inspect.getsource() uses itself
  319. block_finder = inspect.BlockFinder()
  320. # if we start with an indented line, put blockfinder to "started" mode
  321. block_finder.started = source.lines[start][0].isspace()
  322. it = ((x + "\n") for x in source.lines[start:end])
  323. try:
  324. for tok in tokenize.generate_tokens(lambda: next(it)):
  325. block_finder.tokeneater(*tok)
  326. except (inspect.EndOfBlock, IndentationError):
  327. end = block_finder.last + start
  328. except Exception:
  329. pass
  330. # the end might still point to a comment or empty line, correct it
  331. while end:
  332. line = source.lines[end - 1].lstrip()
  333. if line.startswith("#") or not line:
  334. end -= 1
  335. else:
  336. break
  337. return astnode, start, end
  338. def getstatementrange_old(lineno, source, assertion=False):
  339. """ return (start, end) tuple which spans the minimal
  340. statement region which containing the given lineno.
  341. raise an IndexError if no such statementrange can be found.
  342. """
  343. # XXX this logic is only used on python2.4 and below
  344. # 1. find the start of the statement
  345. from codeop import compile_command
  346. for start in range(lineno, -1, -1):
  347. if assertion:
  348. line = source.lines[start]
  349. # the following lines are not fully tested, change with care
  350. if 'super' in line and 'self' in line and '__init__' in line:
  351. raise IndexError("likely a subclass")
  352. if "assert" not in line and "raise" not in line:
  353. continue
  354. trylines = source.lines[start:lineno+1]
  355. # quick hack to prepare parsing an indented line with
  356. # compile_command() (which errors on "return" outside defs)
  357. trylines.insert(0, 'def xxx():')
  358. trysource = '\n '.join(trylines)
  359. # ^ space here
  360. try:
  361. compile_command(trysource)
  362. except (SyntaxError, OverflowError, ValueError):
  363. continue
  364. # 2. find the end of the statement
  365. for end in range(lineno+1, len(source)+1):
  366. trysource = source[start:end]
  367. if trysource.isparseable():
  368. return start, end
  369. raise SyntaxError("no valid source range around line %d " % (lineno,))