summary refs log tree commit diff stats
path: root/nimpretty/nimpretty.nim
diff options
context:
space:
mode:
Diffstat (limited to 'nimpretty/nimpretty.nim')
-rw-r--r--nimpretty/nimpretty.nim50
1 files changed, 47 insertions, 3 deletions
diff --git a/nimpretty/nimpretty.nim b/nimpretty/nimpretty.nim
index e3b4c9a3a..e5abf0d2d 100644
--- a/nimpretty/nimpretty.nim
+++ b/nimpretty/nimpretty.nim
@@ -12,7 +12,7 @@
 when not defined(nimpretty):
   {.error: "This needs to be compiled with --define:nimPretty".}
 
-import ../compiler / [idents, msgs, syntaxes, options, pathutils, layouter]
+import ../compiler / [idents, llstream, ast, msgs, syntaxes, options, pathutils, layouter]
 
 import parseopt, strutils, os, sequtils
 
@@ -48,6 +48,42 @@ type
     indWidth*: Natural
     maxLineLen*: Positive
 
+proc goodEnough(a, b: PNode): bool =
+  if a.kind == b.kind and a.safeLen == b.safeLen:
+    case a.kind
+    of nkNone, nkEmpty, nkNilLit: result = true
+    of nkIdent: result = a.ident.id == b.ident.id
+    of nkSym: result = a.sym == b.sym
+    of nkType: result = true
+    of nkCharLit, nkIntLit..nkInt64Lit, nkUIntLit..nkUInt64Lit:
+      result = a.intVal == b.intVal
+    of nkFloatLit..nkFloat128Lit:
+      result = a.floatVal == b.floatVal
+    of nkStrLit, nkRStrLit, nkTripleStrLit:
+      result = a.strVal == b.strVal
+    else:
+      for i in 0 ..< a.len:
+        if not goodEnough(a[i], b[i]): return false
+      return true
+  elif a.kind == nkStmtList and a.len == 1:
+    result = goodEnough(a[0], b)
+  elif b.kind == nkStmtList and b.len == 1:
+    result = goodEnough(a, b[0])
+  else:
+    result = false
+
+proc finalCheck(content: string; origAst: PNode): bool {.nimcall.} =
+  var conf = newConfigRef()
+  let oldErrors = conf.errorCounter
+  var parser: Parser
+  parser.em.indWidth = 2
+  let fileIdx = fileInfoIdx(conf, AbsoluteFile "nimpretty_bug.nim")
+
+  openParser(parser, fileIdx, llStreamOpen(content), newIdentCache(), conf)
+  let newAst = parseAll(parser)
+  closeParser(parser)
+  result = conf.errorCounter == oldErrors # and goodEnough(newAst, origAst)
+
 proc prettyPrint*(infile, outfile: string, opt: PrettyOptions) =
   var conf = newConfigRef()
   let fileIdx = fileInfoIdx(conf, AbsoluteFile infile)
@@ -58,8 +94,10 @@ proc prettyPrint*(infile, outfile: string, opt: PrettyOptions) =
   parser.em.indWidth = opt.indWidth
   if setupParser(parser, fileIdx, newIdentCache(), conf):
     parser.em.maxLineLen = opt.maxLineLen
-    discard parseAll(parser)
+    let fullAst = parseAll(parser)
     closeParser(parser)
+    when defined(nimpretty):
+      closeEmitter(parser.em, fullAst, finalCheck)
 
 proc main =
   var outfile, outdir: string
@@ -78,7 +116,13 @@ proc main =
   for kind, key, val in getopt():
     case kind
     of cmdArgument:
-      infiles.add(key.addFileExt(".nim"))
+      if dirExists(key):
+        for file in walkDirRec(key, skipSpecial = true):
+          if file.endsWith(".nim") or file.endsWith(".nimble"):
+            infiles.add(file)
+      else:
+        infiles.add(key.addFileExt(".nim"))
+
     of cmdLongOption, cmdShortOption:
       case normalize(key)
       of "help", "h": writeHelp()