summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorZahary Karadjov <zahary@gmail.com>2011-11-16 02:18:10 +0200
committerZahary Karadjov <zahary@gmail.com>2011-11-18 02:11:15 +0200
commit12bac28d23ab21879a0f40fc7b2b2c875be90f82 (patch)
tree4ae6e5029753950e3124ebcc0dff08575df50cfe
parentecd3c80e7eadbb3db9a8acdc3bb37b6f92e9b66b (diff)
downloadNim-12bac28d23ab21879a0f40fc7b2b2c875be90f82.tar.gz
macros and templates can be expanded anywhere where a type is expected.
This allows for various type selection algorithms to be implemented.
See tests / accept / compile / ttypeselectors.nim for examples.
-rwxr-xr-xcompiler/ast.nim6
-rwxr-xr-xcompiler/sem.nim20
-rwxr-xr-xcompiler/semexprs.nim43
-rwxr-xr-xcompiler/semstmts.nim15
-rwxr-xr-xcompiler/semtypes.nim30
-rw-r--r--tests/accept/compile/ttypeselectors.nim39
6 files changed, 118 insertions, 35 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index de22cbd7d..7f28d7b89 100755
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -633,6 +633,12 @@ proc copyNode*(src: PNode): PNode
 proc copyTree*(src: PNode): PNode
   # does copy its sons!
 
+const nkCallKinds* = {nkCall, nkInfix, nkPrefix, nkPostfix, nkCommand,
+                      nkCallStrLit}
+
+proc isCallExpr*(n: PNode): bool =
+  result = n.kind in nkCallKinds
+
 proc discardSons*(father: PNode)
 
 proc len*(n: PNode): int {.inline.} =
diff --git a/compiler/sem.nim b/compiler/sem.nim
index 78f7df940..a4a0ba1bc 100755
--- a/compiler/sem.nim
+++ b/compiler/sem.nim
@@ -64,6 +64,15 @@ proc ParamsTypeCheck(c: PContext, typ: PType) {.inline.} =
   if not typeAllowed(typ, skConst):
     GlobalError(typ.n.info, errXisNoType, typeToString(typ))
 
+proc expectMacroOrTemplateCall(c: PContext, n: PNode): PSym
+
+proc semTemplateExpr(c: PContext, n: PNode, s: PSym, semCheck = true): PNode
+
+proc semMacroExpr(c: PContext, n: PNode, sym: PSym, 
+                  semCheck: bool = true): PNode
+
+proc semWhen(c: PContext, n: PNode, semCheck: bool = true): PNode
+
 include semtempl
 
 proc semConstExpr(c: PContext, n: PNode): PNode = 
@@ -92,13 +101,16 @@ include seminst, semcall
 proc semAfterMacroCall(c: PContext, n: PNode, s: PSym): PNode = 
   result = n
   case s.typ.sons[0].kind
-  of tyExpr: 
+  of tyExpr:
     # BUGFIX: we cannot expect a type here, because module aliases would not 
     # work then (see the ``tmodulealias`` test)
     # semExprWithType(c, result)
-    result = semExpr(c, result) 
-  of tyStmt: result = semStmt(c, result)
-  of tyTypeDesc: result.typ = semTypeNode(c, result, nil)
+    result = semExpr(c, result)
+  of tyStmt:
+    result = semStmt(c, result)
+  of tyTypeDesc:
+    if n.kind == nkStmtList: result.kind = nkStmtListType
+    result.typ = semTypeNode(c, result, nil)
   else:
     result = semExpr(c, result)
     result = fitNode(c, s.typ.sons[0], result)
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index c1d2f1fec..af3048586 100755
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -400,10 +400,6 @@ proc isAssignable(c: PContext, n: PNode): TAssignableResult =
   else: 
     nil
 
-proc isCallExpr(n: PNode): bool = 
-  result = n.kind in {nkCall, nkInfix, nkPrefix, nkPostfix, nkCommand,
-                      nkCallStrLit}
-
 proc newHiddenAddrTaken(c: PContext, n: PNode): PNode = 
   if n.kind == nkHiddenDeref: 
     checkSonsLen(n, 1)
@@ -922,32 +918,35 @@ proc expectStringArg(c: PContext, n: PNode, i: int): PNode =
   if result.kind notin {nkStrLit, nkRStrLit, nkTripleStrLit}:
     GlobalError(result.info, errStringLiteralExpected)
 
-proc semExpandToAst(c: PContext, n: PNode, magicSym: PSym, 
-                    flags: TExprFlags): PNode =
-  if sonsLen(n) == 2:
-    if not isCallExpr(n.sons[1]):
-      GlobalError(n.info, errXisNoMacroOrTemplate, n.renderTree)
+proc expectMacroOrTemplateCall(c: PContext, n: PNode): PSym =
+  ## The argument to the proc should be nkCall(...) or similar
+  ## Returns the macro/template symbol
+  if not isCallExpr(n):
+    GlobalError(n.info, errXisNoMacroOrTemplate, n.renderTree)
 
-    var macroCall = n.sons[1]
+  var expandedSym = qualifiedLookup(c, n[0], {checkUndeclared})
+  if expandedSym == nil:
+    GlobalError(n.info, errUndeclaredIdentifier, n[0].renderTree)
 
-    var expandedSym = qualifiedLookup(c, macroCall.sons[0], {checkUndeclared})
-    if expandedSym == nil:
-      GlobalError(n.info, errUndeclaredIdentifier, macroCall[0].renderTree)
+  if expandedSym.kind notin {skMacro, skTemplate}:
+    GlobalError(n.info, errXisNoMacroOrTemplate, expandedSym.name.s)
 
-    if expandedSym.kind notin {skMacro, skTemplate}:
-      GlobalError(n.info, errXisNoMacroOrTemplate, expandedSym.name.s)
+  result = expandedSym
 
-    macroCall.sons[0] = newNodeI(nkSym, macroCall.info)
-    macroCall.sons[0].sym = expandedSym
+proc semExpandToAst(c: PContext, n: PNode, magicSym: PSym,
+                    flags: TExprFlags): PNode =
+  if sonsLen(n) == 2:
+    var macroCall = n[1]
+    var expandedSym = expectMacroOrTemplateCall(c, macroCall)
+
+    macroCall.sons[0] = newSymNode(expandedSym, macroCall.info)
     markUsed(n, expandedSym)
 
     for i in countup(1, macroCall.len-1):
-      macroCall.sons[i] = semExprWithType(c, macroCall.sons[i], {efAllowType})
+      macroCall.sons[i] = semExprWithType(c, macroCall[i], {efAllowType})
 
-    # Preserve the magic symbol in order to handled in evals.nim
-    n.sons[0] = newNodeI(nkSym, n.info)
-    n.sons[0].sym = magicSym
-    
+    # Preserve the magic symbol in order to be handled in evals.nim
+    n.sons[0] = newSymNode(magicSym, n.info)
     n.typ = expandedSym.getReturnType
     result = n
   else:
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 444f55883..8412d4783 100755
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -12,8 +12,15 @@
 proc semCommand(c: PContext, n: PNode): PNode =
   result = semExprNoType(c, n)
 
-proc semWhen(c: PContext, n: PNode): PNode = 
+proc semWhen(c: PContext, n: PNode, semCheck = true): PNode =
+  # If semCheck is set to false, ``when`` will return the verbatim AST of
+  # the correct branch. Otherwise the AST will be passed through semStmt.
   result = nil
+  
+  template set_result(e: expr) =
+    if semCheck: result = semStmt(c, e) # do not open a new scope!
+    else: result = e
+
   for i in countup(0, sonsLen(n) - 1): 
     var it = n.sons[i]
     case it.kind
@@ -21,12 +28,12 @@ proc semWhen(c: PContext, n: PNode): PNode =
       checkSonsLen(it, 2)
       var e = semAndEvalConstExpr(c, it.sons[0])
       if (e.kind != nkIntLit): InternalError(n.info, "semWhen")
-      if (e.intVal != 0) and (result == nil): 
-        result = semStmt(c, it.sons[1]) # do not open a new scope!
+      if (e.intVal != 0) and (result == nil):
+        set_result(it.sons[1]) 
     of nkElse: 
       checkSonsLen(it, 1)
       if result == nil: 
-        result = semStmt(c, it.sons[0]) # do not open a new scope!
+        set_result(it.sons[0])
     else: illFormedAst(n)
   if result == nil: 
     result = newNodeI(nkNilLit, n.info) 
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index 1d855d97f..e51e02f2d 100755
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -569,16 +569,16 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode,
   #if matchType(result, [(tyProc, 1), (tyVar, 0)], tyGenericInvokation):
   #  debug result
 
-proc semStmtListType(c: PContext, n: PNode, prev: PType): PType = 
+proc semStmtListType(c: PContext, n: PNode, prev: PType): PType =
   checkMinSonsLen(n, 1)
   var length = sonsLen(n)
-  for i in countup(0, length - 2): 
+  for i in countup(0, length - 2):
     n.sons[i] = semStmt(c, n.sons[i])
-  if length > 0: 
+  if length > 0:
     result = semTypeNode(c, n.sons[length - 1], prev)
     n.typ = result
     n.sons[length - 1].typ = result
-  else: 
+  else:
     result = nil
   
 proc semBlockType(c: PContext, n: PNode, prev: PType): PType = 
@@ -630,6 +630,18 @@ proc semGeneric(c: PContext, n: PNode, s: PSym, prev: PType): PType =
     if s.ast == nil: GlobalError(n.info, errCannotInstantiateX, s.name.s)
     result = instGenericContainer(c, n, result)
 
+proc semExpandToType(c: PContext, n: PNode, sym: PSym): PType =
+  # Expands a macro or template until a type is returned
+  # results in GlobalError if the macro expands to something different
+  markUsed(n, sym)
+  case sym.kind
+  of skMacro:
+    result = semTypeNode(c, semMacroExpr(c, n, sym), nil)
+  of skTemplate:
+    result = semTypeNode(c, semTemplateExpr(c, n, sym), nil)
+  else:
+    GlobalError(n.info, errXisNoMacroOrTemplate, n.renderTree)
+
 proc semTypeNode(c: PContext, n: PNode, prev: PType): PType = 
   result = nil
   if gCmd == cmdIdeTools: suggestExpr(c, n)
@@ -642,7 +654,15 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
   of nkPar: 
     if sonsLen(n) == 1: result = semTypeNode(c, n.sons[0], prev)
     else: GlobalError(n.info, errTypeExpected)
-  of nkBracketExpr: 
+  of nkCallKinds:
+    # expand macros and templates
+    var expandedSym = expectMacroOrTemplateCall(c, n)
+    result = semExpandToType(c, n, expandedSym)
+  of nkWhenStmt:
+    var whenResult = semWhen(c, n, false)
+    if whenResult.kind == nkStmtList: whenResult.kind = nkStmtListType
+    result = semTypeNode(c, whenResult, prev)
+  of nkBracketExpr:
     checkMinSonsLen(n, 2)
     var s = semTypeIdent(c, n.sons[0])
     case s.magic
diff --git a/tests/accept/compile/ttypeselectors.nim b/tests/accept/compile/ttypeselectors.nim
new file mode 100644
index 000000000..1cc4b02b7
--- /dev/null
+++ b/tests/accept/compile/ttypeselectors.nim
@@ -0,0 +1,39 @@
+import macros
+
+template selectType(x: int): typeDesc =
+  when x < 10:
+    int
+  else:
+    string
+
+template simpleTypeTempl: typeDesc =
+  string
+
+macro typeFromMacro(s: expr): typeDesc =
+  result = newNimNode(nnkIdent)
+  result.ident = !"string"
+  # result = newIdentNode"string"
+  
+proc t1*(x: int): simpleTypeTempl() =
+  result = "test"
+
+proc t2*(x: int): selectType(100) =
+  result = "test"
+
+proc t3*(x: int): selectType(1) =
+  result = 10
+
+proc t4*(x: int): typeFromMacro() =
+  result = "test"
+
+var x*: selectType(50) = "test"
+
+proc t5*(x: selectType(5)) =
+  var y = x + 10
+  echo y
+
+var y*: type(t2(100)) = "test"
+
+proc t6*(x: type(t3(0))): type(t1(0)) =
+  result = $x
+