about summary refs log tree commit diff stats
path: root/src/io/socketstream.nim
blob: 5744ad32fe0ff721af2833c307661a4ab151d86d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import std/nativesockets
import std/net
import std/os

import io/dynstream
import io/posixstream
import io/serversocket

type
  SocketStream* = ref object of PosixStream
    ## A PosixStream whose reads, writes and close go through a std/net
    ## `Socket`.
    source*: Socket ## Underlying socket used by recvData, sendData and sclose.

method recvData*(s: SocketStream; buffer: pointer; len: int): int =
  ## Reads up to `len` bytes from the socket into `buffer` and returns the
  ## number of bytes read.
  ##
  ## A negative result from the socket is reported via raisePosixIOError.
  ## The first zero-length read marks the stream as ended (`isend`); a
  ## zero-length read on an already-ended stream raises EOFError.
  result = s.source.recv(buffer, len)
  if result < 0:
    raisePosixIOError()
  elif result == 0:
    if unlikely(s.isend):
      raise newException(EOFError, "eof")
    s.isend = true

method sendData*(s: SocketStream; buffer: pointer; len: int): int =
  ## Writes up to `len` bytes from `buffer` to the socket and returns the
  ## number of bytes actually written. A negative result from the socket is
  ## reported via raisePosixIOError.
  result = s.source.send(buffer, len)
  if result < 0:
    raisePosixIOError()

# C helper (compiled from sendfd.c) that transfers the descriptor `fd` to the
# peer over the socket `sock`. sendFileHandle treats a negative return as an
# error and expects 1 on success.
# NOTE(review): the C return type is not visible here; confirm that it
# matches Nim's `int` (e.g. ssize_t on 64-bit targets).
{.compile: "sendfd.c".}
proc sendfd(sock, fd: cint): int {.importc.}

proc sendFileHandle*(s: SocketStream; fd: FileHandle) =
  ## Sends the file descriptor `fd` to the peer over this socket using the C
  ## helper `sendfd`, which writes directly to `s.fd`.
  ##
  ## A negative result is reported via raisePosixIOError.
  # NOTE(review): presumably asserted because sendfd bypasses `source`, so
  # any data std/net has buffered would be out of sync with the raw fd.
  assert not s.source.hasDataBuffered
  let sent = sendfd(s.fd, cint(fd))
  if sent < 0:
    raisePosixIOError()
  # the payload is a single NUL byte carried alongside the descriptor
  assert sent == 1

# C helper (compiled from recvfd.c) that receives a descriptor from the peer
# on socket `sock` and stores it through `fdout`. recvFileHandle treats a
# negative return as an error.
# NOTE(review): whether `fdout` is written on every non-negative return is
# not visible here; confirm against recvfd.c.
{.compile: "recvfd.c".}
proc recvfd(sock: cint; fdout: ptr cint): int {.importc.}

proc recvFileHandle*(s: SocketStream): FileHandle =
  ## Receives a file descriptor sent by the peer over this socket, using the
  ## C helper `recvfd` on the raw `s.fd`.
  ##
  ## A negative result from `recvfd` is reported via raisePosixIOError.
  ## If `recvfd` returns without storing a descriptor, -1 (an invalid handle)
  ## is returned instead of 0, so the result can never silently alias stdin.
  # NOTE(review): presumably asserted because recvfd reads the raw fd and
  # would skip over data std/net has already buffered in `source`.
  assert not s.source.hasDataBuffered
  # Sentinel value: a zero-initialised cint would turn "nothing received"
  # into descriptor 0 (stdin).
  var fd = cint(-1)
  let n = recvfd(s.fd, addr fd)
  if n < 0:
    raisePosixIOError()
  return FileHandle(fd)

method setBlocking*(s: SocketStream; blocking: bool) =
  ## Records the requested mode in `s.blocking`, then applies it to the
  ## underlying socket's descriptor.
  let handle = s.source.getFd()
  s.blocking = blocking
  handle.setBlocking(blocking)

method seek*(s: SocketStream; off: int) =
  ## Sockets are not seekable. Calling this is always a programming error and
  ## fails with an AssertionDefect, in release builds too, because it uses
  ## doAssert.
  doAssert false, "seek is not supported on SocketStream"

method sclose*(s: SocketStream) =
  ## Closes the underlying std/net socket. `s.fd` itself is left unchanged.
  close(s.source)

# see serversocket.nim for an explanation
# C helpers compiled from connect_unix.c. connectAtSocketStream0 treats any
# non-zero return as failure and reads the error with osLastError().
{.compile: "connect_unix.c".}
# Connects socket `fd` to the AF_UNIX path `path` of length `pathlen`.
proc connect_unix_from_c(fd: cint; path: cstring; pathlen: cint): cint
  {.importc.}
when defined(freebsd):
  # for FreeBSD/capsicum: connects `sockFd` to `rel_path`, resolved relative
  # to the directory descriptor `baseFd`.
  proc connectat_unix_from_c(baseFd, sockFd: cint; rel_path: cstring;
    rel_pathlen: cint): cint {.importc.}

proc connectAtSocketStream0(socketDir: string; baseFd, pid: int;
    blocking = true): SocketStream =
  ## Opens an unbuffered AF_UNIX stream socket and connects it to the socket
  ## belonging to process `pid`.
  ##
  ## * If `baseFd` is -1, connects to the path `getSocketPath(socketDir, pid)`.
  ## * Otherwise (FreeBSD only, for capsicum), connects to
  ##   `getSocketName(pid)` resolved relative to the directory descriptor
  ##   `baseFd`. On other platforms a `baseFd` other than -1 is a programming
  ##   error (AssertionDefect).
  ##
  ## If `blocking` is false, the socket is made non-blocking before the
  ## connect call.
  ##
  ## Raises OSError when setting the mode or connecting fails. In that case
  ## the socket is closed before the error propagates, so the descriptor is
  ## not leaked (connectSocketStream swallows the error and retries are
  ## common).
  let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM,
    Protocol.IPPROTO_IP, buffered = false)
  try:
    if not blocking:
      sock.getFd().setBlocking(false)
    let path = getSocketPath(socketDir, pid)
    if baseFd == -1:
      if connect_unix_from_c(cint(sock.getFd()), cstring(path),
          cint(path.len)) != 0:
        raiseOSError(osLastError())
    else:
      when defined(freebsd):
        let name = getSocketName(pid)
        if connectat_unix_from_c(cint(baseFd), cint(sock.getFd()),
            cstring(name), cint(name.len)) != 0:
          raiseOSError(osLastError())
      else:
        # shouldn't have sockDirFd on other architectures
        doAssert false
  except CatchableError:
    # The OS error code is already captured in the exception; release the
    # descriptor, then propagate the original error unchanged.
    sock.close()
    raise
  return SocketStream(
    source: sock,
    fd: cint(sock.getFd()),
    blocking: blocking
  )

proc connectSocketStream*(socketDir: string; baseFd, pid: int;
    blocking = true): SocketStream =
  ## Non-raising wrapper around connectAtSocketStream0. Returns the connected
  ## stream, or nil if connecting raised an OSError. Other exceptions
  ## propagate.
  result = nil
  try:
    result = connectAtSocketStream0(socketDir, baseFd, pid, blocking)
  except OSError:
    discard

proc acceptSocketStream*(ssock: ServerSocket; blocking = true): SocketStream =
  ## Accepts one connection on `ssock` and wraps it in a SocketStream.
  ##
  ## The accepted socket is created with `inheritable = true`. If `blocking`
  ## is false, it is switched to non-blocking mode before being returned.
  var client: Socket
  accept(ssock.sock, client, inheritable = true)
  let handle = client.getFd()
  if not blocking:
    handle.setBlocking(false)
  SocketStream(source: client, fd: cint(handle), blocking: blocking)