about summary refs log tree commit diff stats
path: root/src/ips/socketstream.nim
blob: 5a030427c82b1b813f070c8eb05c40d7678ba0b5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import nativesockets
import net
import os
import streams

when defined(posix):
  import posix

import io/posixstream
import ips/serversocket

type SocketStream* = ref object of Stream
  ## Stream backed by a connected `net.Socket` (connected via connectUnix
  ## or accepted from a ServerSocket by the procs in this module).
  source*: Socket ## underlying socket; reads, writes and close go through it
  blk*: bool ## true: reads loop until the requested length is filled;
             ## false: each read issues a single recv call
  isend: bool ## set once a read observes end of stream; reported by atEnd

proc sockReadData(s: Stream, buffer: pointer, len: int): int =
  ## readDataImpl for SocketStream.
  ##
  ## In blocking mode (`blk`), recv is called repeatedly until `len` bytes
  ## have been read, the peer closes the connection, or recv fails. In
  ## non-blocking mode a single recv call is made.
  ##
  ## If nothing could be read because of end of stream, the stream is marked
  ## as ended and EOFError is raised. If recv fails before any byte was read,
  ## the error is raised via raisePosixIOError. A partial read that is cut
  ## short by EOF or an error returns the bytes read so far (EOF additionally
  ## marks the stream as ended).
  assert len != 0
  let s = SocketStream(s)
  if s.blk:
    while result < len:
      let n = s.source.recv(cast[pointer](cast[int](buffer) + result), len - result)
      if n < 0:
        # Only report the error if nothing was read; otherwise return the
        # partial data.
        if result == 0:
          result = n
        break
      elif n == 0:
        s.isend = true
        break
      result += n
  else:
    result = s.source.recv(buffer, len)
  if result < 0:
    raisePosixIOError()
  elif result == 0:
    s.isend = true
    raise newException(EOFError, "eof")

proc sockWriteData(s: Stream, buffer: pointer, len: int) =
  ## writeDataImpl for SocketStream: sends all `len` bytes starting at
  ## `buffer`, calling send again after short writes. A failing send raises
  ## IOError carrying the strerror text of errno.
  #TODO maybe don't block if blk is false?
  let sock = SocketStream(s).source
  var sent = 0
  while sent < len:
    let chunk = sock.send(cast[pointer](cast[int](buffer) + sent), len - sent)
    if chunk < 0:
      raise newException(IOError, $strerror(errno))
    sent += chunk

proc sockAtEnd(s: Stream): bool =
  ## atEndImpl for SocketStream: reports the `isend` flag, which the read
  ## implementation sets once recv signals end of stream.
  let stream = SocketStream(s)
  return stream.isend

proc sockClose(s: Stream) =
  ## closeImpl for SocketStream: closes the underlying socket.
  # The effect tags of Socket.close are cast away, presumably so this proc
  # fits the tag list allowed by the Stream.closeImpl field type.
  {.cast(tags: []).}:
    SocketStream(s).source.close()

# See https://stackoverflow.com/a/4491203
proc sendFileHandle*(s: SocketStream, fd: FileHandle) =
  ## Sends the file descriptor `fd` to the peer of `s` as an SCM_RIGHTS
  ## control message. The message also carries a single zero byte of
  ## regular data (presumably because ancillary data must accompany real
  ## data on stream sockets).
  ##
  ## Fails with an AssertionDefect if sendmsg does not report sending that
  ## one byte.
  var hdr: Tmsghdr
  var iov: IOVec
  var space: csize_t
  # CMSG_SPACE/CMSG_LEN are computed in emitted C; presumably because the
  # Nim-side types of the msghdr/cmsghdr length fields vary by platform.
  {.emit: [
  space, """ = CMSG_SPACE(sizeof(int));""",
  ].}
  # alloc0 so any padding inside the control buffer is zeroed instead of
  # containing leftover heap contents.
  let cmsgbuf = alloc0(cast[int](space))
  defer: dealloc(cmsgbuf)
  var buf = char(0)
  iov.iov_base = addr buf
  iov.iov_len = csize_t(1)
  zeroMem(addr hdr, sizeof(hdr))
  hdr.msg_iov = addr iov
  hdr.msg_iovlen = 1
  hdr.msg_control = cmsgbuf
  # ...sigh
  {.emit: [
  hdr.msg_controllen, """ = CMSG_LEN(sizeof(int));""",
  ].}
  let cmsg = CMSG_FIRSTHDR(addr hdr)
  # FileHandle is cint, so sizeof(FileHandle) in C is sizeof(int); the
  # emitted sizeof(int) expressions rely on this.
  when sizeof(FileHandle) != sizeof(cint):
    {.error: "sendFileHandle assumes sizeof(FileHandle) == sizeof(cint)".}
  {.emit: [
  cmsg.cmsg_len, """ = CMSG_LEN(sizeof(int));"""
  ].}
  cmsg.cmsg_level = SOL_SOCKET
  cmsg.cmsg_type = SCM_RIGHTS
  cast[ptr FileHandle](CMSG_DATA(cmsg))[] = fd
  let n = sendmsg(s.source.getFd(), addr hdr, 0)
  # doAssert (not assert) so the check survives -d:danger; the buffer is
  # released by the defer above only after osLastError has been read.
  #TODO report failure to the caller instead of asserting
  doAssert n == int(iov.iov_len), "Failed to send file handle " &
    $osLastError()

proc recvFileHandle*(s: SocketStream): FileHandle =
  ## Receives a file descriptor from the peer of `s`. It reads one byte of
  ## regular data plus an SCM_RIGHTS control message, as produced by
  ## sendFileHandle, and returns the descriptor the message carries.
  ##
  ## Fails with an AssertionDefect on EOF, on a recvmsg error, or when the
  ## received message carries no SCM_RIGHTS control data.
  var iov: IOVec
  var hdr: Tmsghdr
  var buf: char
  let cmsgbuf = alloc0(CMSG_SPACE(csize_t(sizeof(FileHandle))))
  # Released on every exit path, including a failed assertion.
  defer: dealloc(cmsgbuf)
  iov.iov_base = addr buf
  iov.iov_len = 1
  zeroMem(addr hdr, sizeof(hdr))
  hdr.msg_iov = addr iov
  hdr.msg_iovlen = 1
  hdr.msg_control = cmsgbuf
  {.emit: [
  hdr.msg_controllen, """ = CMSG_SPACE(sizeof(int));"""
  ].}
  let n = recvmsg(s.source.getFd(), addr hdr, 0)
  #TODO report these failures to the caller instead of asserting
  doAssert n != 0, "Unexpected EOF"
  doAssert n > 0, "Failed to receive message " & $osLastError()
  let cmsg = CMSG_FIRSTHDR(addr hdr)
  # CMSG_FIRSTHDR yields nil when no control data was received; check it
  # before dereferencing, and make sure it actually carries a descriptor.
  doAssert cmsg != nil, "No control message received"
  doAssert(cmsg.cmsg_level == SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS,
    "Received control message is not SCM_RIGHTS")
  result = cast[ptr FileHandle](CMSG_DATA(cmsg))[]

func newSocketStream*(): SocketStream =
  ## Creates a SocketStream with its Stream callbacks wired to the socket
  ## implementations in this module. `source` is left unset; callers
  ## connect or accept a socket and assign it afterwards.
  # sockReadData is cast to the exact readDataImpl field type (raises/tags).
  let readImpl = cast[proc (s: Stream, buffer: pointer, bufLen: int): int
      {.nimcall, raises: [Defect, IOError, OSError], tags: [ReadIOEffect], gcsafe.}
  ](sockReadData) # ... ???
  SocketStream(
    readDataImpl: readImpl,
    writeDataImpl: sockWriteData,
    atEndImpl: sockAtEnd,
    closeImpl: sockClose
  )

proc connectSocketStream*(path: string, buffered = true, blocking = true): SocketStream =
  ## Connects to the Unix domain stream socket at `path` and wraps the
  ## connection in a SocketStream.
  ##
  ## `buffered` is passed on to newSocket. When `blocking` is false, the
  ## socket fd is switched to non-blocking mode and the stream's reads issue
  ## a single recv each.
  ##
  ## Errors from the socket calls (e.g. OSError from connectUnix) propagate
  ## to the caller; the freshly created socket is closed first so its fd is
  ## not leaked.
  result = newSocketStream()
  result.blk = blocking
  let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered)
  try:
    if not blocking:
      sock.getFd().setBlocking(false)
    connectUnix(sock, path)
  except CatchableError:
    sock.close()
    raise
  result.source = sock

proc connectSocketStream*(pid: Pid, buffered = true, blocking = true): SocketStream =
  ## Connects to the socket belonging to process `pid` (its path is given by
  ## getSocketPath). An OSError during connection is not propagated; the
  ## result is nil instead.
  result = nil
  try:
    result = connectSocketStream(getSocketPath(pid), buffered, blocking)
  except OSError:
    # Connection failure is reported to the caller as a nil stream.
    discard

proc acceptSocketStream*(ssock: ServerSocket, blocking = true): SocketStream =
  ## Accepts a connection on `ssock` and wraps it in a SocketStream. The
  ## accepted socket is created inheritable.
  ##
  ## When `blocking` is false, the accepted fd is explicitly switched to
  ## non-blocking mode, as connectSocketStream does. Without this, the
  ## stream's single-recv reads could still block, since accept(2) on Linux
  ## does not carry O_NONBLOCK over from the listening socket.
  result = newSocketStream()
  result.blk = blocking
  var sock: Socket
  ssock.sock.accept(sock, inheritable = true)
  if not blocking:
    try:
      sock.getFd().setBlocking(false)
    except CatchableError:
      # Don't leak the accepted fd if it cannot be configured.
      sock.close()
      raise
  result.source = sock