[lxc-devel] [lxd/master] [RFC] proxy: port ranges

brauner on Github lxc-bot at linuxcontainers.org
Mon Jun 18 12:50:56 UTC 2018


A non-text attachment was scrubbed...
Name: not available
Type: text/x-mailbox
Size: 853 bytes
Desc: not available
URL: <http://lists.linuxcontainers.org/pipermail/lxc-devel/attachments/20180618/6aa4c0a5/attachment.bin>
-------------- next part --------------
From 4fcd67fccbd91d81df3ae1bc21bbf887c8ed265d Mon Sep 17 00:00:00 2001
From: Christian Brauner <christian.brauner at ubuntu.com>
Date: Mon, 18 Jun 2018 14:32:28 +0200
Subject: [PATCH 1/4] eagain: retry on EINTR as well as EAGAIN

Signed-off-by: Christian Brauner <christian.brauner at ubuntu.com>
---
 shared/eagain/file.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/shared/eagain/file.go b/shared/eagain/file.go
index 9e3eac9c0..2739969df 100644
--- a/shared/eagain/file.go
+++ b/shared/eagain/file.go
@@ -22,7 +22,7 @@ again:
 
 	// keep retrying on EAGAIN
 	errno, ok := shared.GetErrno(err)
-	if ok && errno == syscall.EAGAIN {
+	if ok && (errno == syscall.EAGAIN || errno == syscall.EINTR) {
 		goto again
 	}
 
@@ -44,7 +44,7 @@ again:
 
 	// keep retrying on EAGAIN
 	errno, ok := shared.GetErrno(err)
-	if ok && errno == syscall.EAGAIN {
+	if ok && (errno == syscall.EAGAIN || errno == syscall.EINTR) {
 		goto again
 	}
 

From 9302ce50bce392129ccd722f4cf2306541591621 Mon Sep 17 00:00:00 2001
From: Christian Brauner <christian.brauner at ubuntu.com>
Date: Sat, 16 Jun 2018 13:09:18 +0200
Subject: [PATCH 2/4] proxy: genericize to handle multiple ports

Closes #4601.

Signed-off-by: Christian Brauner <christian.brauner at ubuntu.com>
---
 lxd/main_forkproxy.go | 171 ++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 131 insertions(+), 40 deletions(-)

diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go
index 074bb8c72..93a367718 100644
--- a/lxd/main_forkproxy.go
+++ b/lxd/main_forkproxy.go
@@ -6,6 +6,7 @@ import (
 	"net"
 	"os"
 	"os/signal"
+	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -260,7 +261,7 @@ type cmdForkproxy struct {
 
 type proxyAddress struct {
 	connType string
-	addr     string
+	addr     []string
 	abstract bool
 }
 
@@ -307,45 +308,71 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 		return fmt.Errorf("Failed to call forkproxy constructor")
 	}
 
-	lAddr := parseAddr(listenAddr)
+	lAddr, err := parseAddr(listenAddr)
+	if err != nil {
+		return err
+	}
 
 	if C.whoami == C.FORKPROXY_CHILD {
-		err := os.Remove(lAddr.addr)
-		if err != nil && !os.IsNotExist(err) {
-			return err
+		if lAddr.connType == "unix"  && !lAddr.abstract {
+			err := os.Remove(lAddr.addr[0])
+			if err != nil && !os.IsNotExist(err) {
+				return err
+			}
 		}
 
-		file, err := getListenerFile(listenAddr)
-		if err != nil {
-			return err
+		for _, port := range lAddr.addr {
+			fmt.Println(port)
+		}
+
+		for _, addr := range lAddr.addr {
+			file, err := getListenerFile(lAddr.connType, addr)
+			if err != nil {
+				return err
+			}
+
+			err = shared.AbstractUnixSendFd(forkproxyUDSSockFDNum, int(file.Fd()))
+			file.Close()
+			if err != nil {
+				break
+			}
 		}
 
-		err = shared.AbstractUnixSendFd(forkproxyUDSSockFDNum, int(file.Fd()))
 		syscall.Close(forkproxyUDSSockFDNum)
-		file.Close()
 		return err
 	}
 
-	file, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum)
-	syscall.Close(forkproxyUDSSockFDNum)
-	if err != nil {
-		fmt.Printf("Failed to receive fd from listener process: %v\n", err)
-		return err
+	files := []*os.File{}
+	for range lAddr.addr {
+		f, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum)
+		if err != nil {
+			fmt.Printf("Failed to receive fd from listener process: %v\n", err)
+			return err
+		}
+		files = append(files, f)
 	}
+	syscall.Close(forkproxyUDSSockFDNum)
 
 	var srcConn net.Conn
-	var listener net.Listener
+	var listeners []*net.Listener
 
 	udpFD := -1
 	if lAddr.connType == "udp" {
-		udpFD = int(file.Fd())
-		srcConn, err = net.FileConn(file)
+		udpFD = int(files[0].Fd())
+		srcConn, err = net.FileConn(files[0])
+		if err != nil {
+			fmt.Printf("Failed to re-assemble listener: %v", err)
+			return err
+		}
 	} else {
-		listener, err = net.FileListener(file)
-	}
-	if err != nil {
-		fmt.Printf("Failed to re-assemble listener: %v", err)
-		return err
+		for _, f := range files {
+			listener, err := net.FileListener(f)
+			if err != nil {
+				fmt.Printf("Failed to re-assemble listener: %v", err)
+				return err
+			}
+			listeners = append(listeners, &listener)
+		}
 	}
 
 	// Handle SIGTERM which is sent when the proxy is to be removed
@@ -358,33 +385,44 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	go func() {
 		<-sigs
 		terminate = true
-		file.Close()
+
+		for _, f := range files {
+			f.Close()
+		}
+
 		if lAddr.connType == "udp" {
 			srcConn.Close()
 			// Kill ourselves since we will otherwise block on UDP
 			// connect() or poll().
 			syscall.Kill(killOnUDP, syscall.SIGKILL)
 		} else {
-			listener.Close()
+			for _, listener := range listeners {
+				(*listener).Close()
+			}
 		}
 	}()
 
 	connectAddr := args[3]
-	cAddr := parseAddr(connectAddr)
+	cAddr, err := parseAddr(connectAddr)
+	if err != nil {
+		return err
+	}
 
 	if cAddr.connType == "unix" && !cAddr.abstract {
 		// Create socket
-		file, err := getListenerFile(fmt.Sprintf("unix:%s", cAddr.addr))
+		file, err := getListenerFile("unix", cAddr.addr[0])
 		if err != nil {
 			return err
 		}
 		file.Close()
 
-		defer os.Remove(cAddr.addr)
+		if cAddr.connType == "unix" && !cAddr.abstract {
+			defer os.Remove(cAddr.addr[0])
+		}
 	}
 
 	if lAddr.connType == "unix" && !lAddr.abstract {
-		defer os.Remove(lAddr.addr)
+		defer os.Remove(lAddr.addr[0])
 	}
 
 	fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType)
@@ -418,7 +456,7 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 		// begin proxying
 		for {
 			// Accept a new client
-			srcConn, err = listener.Accept()
+			srcConn, err = (*listeners[0]).Accept()
 			if err != nil {
 				if terminate {
 					break
@@ -619,15 +657,12 @@ func tryListenUDP(protocol string, addr string) (*os.File, error) {
 	return file, err
 }
 
-func getListenerFile(listenAddr string) (*os.File, error) {
-	fields := strings.SplitN(listenAddr, ":", 2)
-	addr := strings.Join(fields[1:], "")
-
-	if fields[0] == "udp" {
-		return tryListenUDP(fields[0], addr)
+func getListenerFile(protocol string, addr string) (*os.File, error) {
+	if protocol == "udp" {
+		return tryListenUDP("udp", addr)
 	}
 
-	listener, err := tryListen(fields[0], addr)
+	listener, err := tryListen(protocol, addr)
 	if err != nil {
 		return nil, fmt.Errorf("Failed to listen on %s: %v", addr, err)
 	}
@@ -654,11 +689,67 @@ func getDestConn(connectAddr string) (net.Conn, error) {
 	return net.Dial(fields[0], addr)
 }
 
-func parseAddr(addr string) *proxyAddress {
+func parsePortRange(r string) (int64, int64, error) {
+	entries := strings.Split(r, "-")
+	if len(entries) > 2 {
+		return -1, -1, fmt.Errorf("Invalid port range %s", r)
+	}
+
+	base, err := strconv.ParseInt(entries[0], 10, 64)
+	if err != nil {
+		return -1, -1, err
+	}
+
+	size := int64(1)
+	if len(entries) > 1 {
+		size, err = strconv.ParseInt(entries[1], 10, 64)
+		if err != nil {
+			return -1, -1, err
+		}
+
+		size -= base
+		size += 1
+	}
+
+	return base, size, nil
+}
+
+func parseAddr(addr string) (*proxyAddress, error) {
+	// Split into <protocol> and <address>
 	fields := strings.SplitN(addr, ":", 2)
-	return &proxyAddress{
+
+	newProxyAddr := &proxyAddress{
 		connType: fields[0],
-		addr:     fields[1],
 		abstract: strings.HasPrefix(fields[1], "@"),
 	}
+
+	// unix addresses cannot have ports
+	if newProxyAddr.connType == "unix" {
+		newProxyAddr.addr = []string{fields[1]}
+		return newProxyAddr, nil
+	}
+
+	// Split <address> into <address> and <ports>
+	addrParts := strings.SplitN(fields[1], ":", 2)
+	// no ports
+	if len(addrParts) == 1 {
+		newProxyAddr.addr = []string{fields[1]}
+		return newProxyAddr, nil
+	}
+
+	// Split <ports> into individual ports and port ranges
+	ports := strings.SplitN(addrParts[1], ",", -1)
+	for _, port := range ports {
+		portFirst, portRange, err := parsePortRange(port)
+		if err != nil {
+			return nil, err
+		}
+
+		for i := int64(0); i < portRange; i++ {
+			newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst + i)
+			newProxyAddr.addr = append(newProxyAddr.addr, newAddr)
+		}
+	}
+
+	return newProxyAddr, nil
 }

From fd58089ce0ef92f79b0116eff5b2dfb616a219cc Mon Sep 17 00:00:00 2001
From: Christian Brauner <christian.brauner at ubuntu.com>
Date: Sat, 16 Jun 2018 14:28:51 +0200
Subject: [PATCH 3/4] proxy: handle UDP and TCP port ranges

Closes #4601.

Signed-off-by: Christian Brauner <christian.brauner at ubuntu.com>
---
 lxd/main_forkproxy.go | 214 ++++++++++++++++++++++++++++++--------------------
 1 file changed, 128 insertions(+), 86 deletions(-)

diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go
index 93a367718..946a87450 100644
--- a/lxd/main_forkproxy.go
+++ b/lxd/main_forkproxy.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 	"syscall"
 	"time"
+	"unsafe"
 
 	"github.com/spf13/cobra"
 
@@ -25,6 +26,7 @@ import (
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
+#include <sys/epoll.h>
 #include <sys/socket.h>
 #include <sys/stat.h>
 #include <sys/types.h>
@@ -283,6 +285,55 @@ func (c *cmdForkproxy) Command() *cobra.Command {
 	return cmd
 }
 
+func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr string, udpSrcConn *net.Conn, listener *net.Listener) error {
+	fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType)
+	if lAddr.connType == "udp" {
+		go func() error {
+			// Connect to the target
+			dstConn, err := getDestConn(connectAddr)
+			if err != nil {
+				fmt.Printf("Error: Failed to connect to target: %v\n", err)
+				(*udpSrcConn).Close()
+				return err
+			}
+
+			genericRelay((*udpSrcConn), dstConn, false)
+
+			return nil
+		}()
+
+		return nil
+	}
+
+	// Accept a new client
+	srcConn, err := (*listener).Accept()
+	if err != nil {
+		fmt.Printf("Error: Failed to accept new connection: %v\n", err)
+		return err
+	}
+	fmt.Printf("Accepted a new connection\n")
+
+	// Connect to the target
+	dstConn, err := getDestConn(connectAddr)
+	if err != nil {
+		fmt.Printf("Error: Failed to connect to target: %v\n", err)
+		if lAddr.connType != "udp" {
+			srcConn.Close()
+		}
+
+		return err
+	}
+
+	if cAddr.connType == "unix" && lAddr.connType == "unix" {
+		// Handle OOB if both src and dst are using unix sockets
+		go unixRelay(srcConn, dstConn)
+	} else {
+		go genericRelay(srcConn, dstConn, true)
+	}
+
+	return nil
+}
+
 func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	// Only root should run this
 	if os.Geteuid() != 0 {
@@ -314,17 +365,13 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	}
 
 	if C.whoami == C.FORKPROXY_CHILD {
-		if lAddr.connType == "unix"  && !lAddr.abstract {
+		if lAddr.connType == "unix" && !lAddr.abstract {
 			err := os.Remove(lAddr.addr[0])
 			if err != nil && !os.IsNotExist(err) {
 				return err
 			}
 		}
 
-		for _, port := range lAddr.addr {
-			fmt.Println(port)
-		}
-
 		for _, addr := range lAddr.addr {
 			file, err := getListenerFile(lAddr.connType, addr)
 			if err != nil {
@@ -347,6 +394,7 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 		f, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum)
 		if err != nil {
 			fmt.Printf("Failed to receive fd from listener process: %v\n", err)
+			syscall.Close(forkproxyUDSSockFDNum)
 			return err
 		}
 		files = append(files, f)
@@ -354,54 +402,36 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	syscall.Close(forkproxyUDSSockFDNum)
 
 	var srcConn net.Conn
-	var listeners []*net.Listener
+	var listenerMap map[int]*net.Listener
+	var udpConnMap map[int]*net.Conn
 
-	udpFD := -1
-	if lAddr.connType == "udp" {
-		udpFD = int(files[0].Fd())
-		srcConn, err = net.FileConn(files[0])
-		if err != nil {
-			fmt.Printf("Failed to re-assemble listener: %v", err)
-			return err
+	isUDPListener := lAddr.connType == "udp"
+	if isUDPListener {
+		udpConnMap = make(map[int]*net.Conn, len(lAddr.addr))
+		for _, f := range files {
+			srcConn, err = net.FileConn(files[0])
+			if err != nil {
+				fmt.Printf("Failed to re-assemble listener: %v", err)
+				return err
+			}
+			udpConnMap[int(f.Fd())] = &srcConn
 		}
 	} else {
+		listenerMap = make(map[int]*net.Listener, len(lAddr.addr))
 		for _, f := range files {
 			listener, err := net.FileListener(f)
 			if err != nil {
 				fmt.Printf("Failed to re-assemble listener: %v", err)
 				return err
 			}
-			listeners = append(listeners, &listener)
+			listenerMap[int(f.Fd())] = &listener
 		}
 	}
 
 	// Handle SIGTERM which is sent when the proxy is to be removed
-	terminate := false
 	sigs := make(chan os.Signal, 1)
 	signal.Notify(sigs, syscall.SIGTERM)
 
-	// Wait for SIGTERM and close the listener in order to exit the loop below
-	killOnUDP := syscall.Getpid()
-	go func() {
-		<-sigs
-		terminate = true
-
-		for _, f := range files {
-			f.Close()
-		}
-
-		if lAddr.connType == "udp" {
-			srcConn.Close()
-			// Kill ourselves since we will otherwise block on UDP
-			// connect() or poll().
-			syscall.Kill(killOnUDP, syscall.SIGKILL)
-		} else {
-			for _, listener := range listeners {
-				(*listener).Close()
-			}
-		}
-	}()
-
 	connectAddr := args[3]
 	cAddr, err := parseAddr(connectAddr)
 	if err != nil {
@@ -425,70 +455,82 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 		defer os.Remove(lAddr.addr[0])
 	}
 
-	fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType)
-	if lAddr.connType == "udp" {
-		for {
-			ret, revents, err := shared.GetPollRevents(udpFD, -1, (shared.POLLIN | shared.POLLPRI | shared.POLLERR | shared.POLLHUP | shared.POLLRDHUP | shared.POLLNVAL))
-			if ret < 0 {
-				fmt.Printf("Failed to poll on file descriptor: %s\n", err)
-				srcConn.Close()
-				return err
-			}
+	epFD := C.epoll_create1(C.EPOLL_CLOEXEC)
+	if epFD < 0 {
+		return fmt.Errorf("Failed to create new epoll instance")
+	}
 
-			if (revents & (shared.POLLERR | shared.POLLHUP | shared.POLLRDHUP | shared.POLLNVAL)) > 0 {
-				err := fmt.Errorf("Invalid UDP socket file descriptor")
-				fmt.Printf("%s\n", err)
-				srcConn.Close()
-				return err
-			}
+	// Wait for SIGTERM and close the listener in order to exit the loop below
+	self := syscall.Getpid()
+	go func() {
+		<-sigs
 
-			// Connect to the target
-			dstConn, err := getDestConn(connectAddr)
-			if err != nil {
-				fmt.Printf("error: Failed to connect to target: %v\n", err)
-				srcConn.Close()
-				return err
+		for _, f := range files {
+			C.epoll_ctl(epFD, C.EPOLL_CTL_DEL, C.int(f.Fd()), nil)
+			f.Close()
+		}
+		syscall.Close(int(epFD))
+
+		if isUDPListener {
+			for _, l := range udpConnMap {
+				(*l).Close()
+			}
+		} else {
+			for _, l := range listenerMap {
+				(*l).Close()
 			}
+		}
+		syscall.Kill(self, syscall.SIGKILL)
+	}()
+	defer syscall.Kill(self, syscall.SIGTERM)
 
-			genericRelay(srcConn, dstConn, false)
+	for _, f := range files {
+		var ev C.struct_epoll_event
+		ev.events = C.EPOLLIN
+		if isUDPListener {
+			ev.events |= C.EPOLLET
 		}
-	} else {
-		// begin proxying
-		for {
-			// Accept a new client
-			srcConn, err = (*listeners[0]).Accept()
-			if err != nil {
-				if terminate {
-					break
-				}
 
-				fmt.Printf("error: Failed to accept new connection: %v\n", err)
-				continue
-			}
-			fmt.Printf("Accepted a new connection\n")
+		*(*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(&ev)) + unsafe.Sizeof(ev.events))) = C.int(f.Fd())
+		ret := C.epoll_ctl(epFD, C.EPOLL_CTL_ADD, C.int(f.Fd()), &ev)
+		if ret < 0 {
+			return fmt.Errorf("Failed to add listener fd to epoll instance")
+		}
+		fmt.Printf("Added listener socket file descriptor %d to epoll instance\n", int(f.Fd()))
+	}
 
-			// Connect to the target
-			dstConn, err := getDestConn(connectAddr)
-			if err != nil {
-				fmt.Printf("error: Failed to connect to target: %v\n", err)
-				if lAddr.connType != "udp" {
-					srcConn.Close()
-				}
+	for {
+		var events [10]C.struct_epoll_event
+
+		nfds := C.epoll_wait(epFD, &events[0], 10, -1)
+		if nfds < 0 {
+			fmt.Printf("Failed to wait on epoll instance")
+			break
+		}
 
+		for i := C.int(0); i < nfds; i++ {
+			var listener *net.Listener
+			var udpListener *net.Conn
+			var ok bool
+
+			curFD := *(*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(&events[i])) + unsafe.Sizeof(events[i].events)))
+			if isUDPListener {
+				udpListener, ok = udpConnMap[int(curFD)]
+			} else {
+				listener, ok = listenerMap[int(curFD)]
+			}
+			if !ok {
 				continue
 			}
 
-			if cAddr.connType == "unix" && lAddr.connType == "unix" {
-				// Handle OOB if both src and dst are using unix sockets
-				go unixRelay(srcConn, dstConn)
-			} else {
-				go genericRelay(srcConn, dstConn, true)
+			err := listenerInstance(lAddr, cAddr, connectAddr, udpListener, listener)
+			if err != nil {
+				fmt.Printf("Failed to prepare new listener instance: %s", err)
 			}
 		}
 	}
 
 	fmt.Printf("Stopping proxy\n")
-
 	return nil
 }
 
@@ -746,7 +788,7 @@ func parseAddr(addr string) (*proxyAddress, error) {
 		}
 
 		for i := int64(0); i < portRange; i++ {
-			newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst + i)
+			newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst+i)
 			newProxyAddr.addr = append(newProxyAddr.addr, newAddr)
 		}
 	}

From 01f19eb8c9309f3b16613dbe143eed66a617d678 Mon Sep 17 00:00:00 2001
From: Christian Brauner <christian.brauner at ubuntu.com>
Date: Mon, 18 Jun 2018 11:46:28 +0200
Subject: [PATCH 4/4] proxy: dump traffic to the same connection

The old implementation called connect() every time a new client was
accepted. If I understand correctly, this is not what we want: ideally,
all clients should send their traffic over the same connection to the
target. This is especially true when forwarding multiple ports.
Unfortunately, this makes the implementation more complex.
In any case, I might be mistaken, and what we actually want is for each
newly accepted client on a forwarded port to trigger its own connect()
call.

Closes #4601.

Signed-off-by: Christian Brauner <christian.brauner at ubuntu.com>
---
 lxd/main_forkproxy.go | 228 ++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 165 insertions(+), 63 deletions(-)

diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go
index 946a87450..7c3d40ccf 100644
--- a/lxd/main_forkproxy.go
+++ b/lxd/main_forkproxy.go
@@ -8,6 +8,7 @@ import (
 	"os/signal"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 	"unsafe"
@@ -285,22 +286,10 @@ func (c *cmdForkproxy) Command() *cobra.Command {
 	return cmd
 }
 
-func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr string, udpSrcConn *net.Conn, listener *net.Listener) error {
-	fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType)
-	if lAddr.connType == "udp" {
-		go func() error {
-			// Connect to the target
-			dstConn, err := getDestConn(connectAddr)
-			if err != nil {
-				fmt.Printf("Error: Failed to connect to target: %v\n", err)
-				(*udpSrcConn).Close()
-				return err
-			}
-
-			genericRelay((*udpSrcConn), dstConn, false)
-
-			return nil
-		}()
+func listenerInstance(lProtocol string, cProtocol string, udpSrcConn *net.Conn, listener *net.Listener, dst net.Conn) error {
+	fmt.Printf("Starting %s <-> %s proxy\n", lProtocol, cProtocol)
+	if lProtocol == "udp" {
+		go genericRelay((*udpSrcConn), dst, true)
 
 		return nil
 	}
@@ -311,29 +300,22 @@ func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr stri
 		fmt.Printf("Error: Failed to accept new connection: %v\n", err)
 		return err
 	}
-	fmt.Printf("Accepted a new connection\n")
 
-	// Connect to the target
-	dstConn, err := getDestConn(connectAddr)
-	if err != nil {
-		fmt.Printf("Error: Failed to connect to target: %v\n", err)
-		if lAddr.connType != "udp" {
-			srcConn.Close()
-		}
+	if lProtocol == "unix" && cProtocol == "unix" {
+		// Handle OOB if both src and dst are using unix sockets
+		go unixRelay(srcConn, dst)
 
-		return err
+		return nil
 	}
 
-	if cAddr.connType == "unix" && lAddr.connType == "unix" {
-		// Handle OOB if both src and dst are using unix sockets
-		go unixRelay(srcConn, dstConn)
-	} else {
-		go genericRelay(srcConn, dstConn, true)
-	}
+	go genericRelay(srcConn, dst, false)
 
 	return nil
 }
 
+var dstConnLock sync.Mutex
+var dstConn *net.Conn
+
 func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	// Only root should run this
 	if os.Geteuid() != 0 {
@@ -523,7 +505,21 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 				continue
 			}
 
-			err := listenerInstance(lAddr, cAddr, connectAddr, udpListener, listener)
+			dstConnLock.Lock()
+			if dstConn == nil {
+				// Connect to the target
+				tmp, err := getDestConn(connectAddr)
+				if err != nil {
+					fmt.Printf("Error: Failed to connect to target: %s\n", err)
+					dstConnLock.Unlock()
+					continue
+				}
+
+				dstConn = &tmp
+			}
+			dstConnLock.Unlock()
+
+			err := listenerInstance(lAddr.connType, cAddr.connType, udpListener, listener, *dstConn)
 			if err != nil {
 				fmt.Printf("Failed to prepare new listener instance: %s", err)
 			}
@@ -534,36 +530,112 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func genericRelay(dst io.ReadWriteCloser, src io.ReadWriteCloser, closeDst bool) {
-	relayer := func(dst io.Writer, src io.Reader, ch chan error) {
-		_, err := io.Copy(eagain.Writer{Writer: dst}, eagain.Reader{Reader: src})
-		ch <- err
+func copyBuffer(dst io.Writer, src io.Reader) (written int64, dstErr error, srcErr error) {
+	// If the reader has a WriteTo method, use it to do the copy.
+	// Avoids an allocation and a copy.
+	if wt, ok := src.(io.WriterTo); ok {
+		written, dstErr = wt.WriteTo(dst)
+		return written, dstErr, nil
+	}
+
+	// Similarly, if the writer has a ReadFrom method, use it to do the copy.
+	if rt, ok := dst.(io.ReaderFrom); ok {
+		written, srcErr = rt.ReadFrom(src)
+		return written, srcErr, nil
+	}
+
+	size := 32 * 1024
+	if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N {
+		if l.N < 1 {
+			size = 1
+		} else {
+			size = int(l.N)
+		}
+	}
+
+	buf := make([]byte, size)
+	for {
+		nr, er := src.Read(buf)
+		if nr > 0 {
+			nw, ew := dst.Write(buf[0:nr])
+			if nw > 0 {
+				written += int64(nw)
+			}
+			if ew != nil {
+				dstErr = ew
+				break
+			}
+			if nr != nw {
+				dstErr = io.ErrShortWrite
+				break
+			}
+		}
+		if er != nil {
+			if er != io.EOF {
+				srcErr = er
+			}
+			break
+		}
+	}
+
+	return written, dstErr, srcErr
+}
+
+func genericRelay(src io.ReadWriteCloser, dst io.ReadWriteCloser, udp bool) {
+	relayer := func(src io.Writer, dst io.Reader, srcCh chan error, dstCh chan error, udp bool) {
+		var srcErr, dstErr error
+
+		if udp {
+			// EPOLLET behavior requires us to stop reading at
+			// EAGAIN so don't handle this error
+			_, srcErr, dstErr = copyBuffer(src, dst)
+		} else {
+			_, srcErr, dstErr = copyBuffer(eagain.Writer{Writer: src}, eagain.Reader{Reader: dst})
+		}
+		srcCh <- srcErr
+		dstCh <- dstErr
 	}
 
-	chSend := make(chan error)
-	go relayer(dst, src, chSend)
+	chSrcSend := make(chan error)
+	chDstSend := make(chan error)
+	go relayer(src, dst, chSrcSend, chDstSend, udp)
 
-	chRecv := make(chan error)
-	go relayer(src, dst, chRecv)
+	chSrcRecv := make(chan error)
+	chDstRecv := make(chan error)
+	go relayer(dst, src, chSrcRecv, chDstRecv, udp)
 
-	errSnd := <-chSend
-	errRcv := <-chRecv
+	errSrcSnd := <-chSrcSend
+	errDstSnd := <-chDstSend
+	errSrcRcv := <-chSrcRecv
+	errDstRcv := <-chDstRecv
 
-	src.Close()
-	if closeDst {
-		dst.Close()
+	if !udp {
+		src.Close()
 	}
 
-	if errSnd != nil {
-		fmt.Printf("Error while sending data %s\n", errSnd)
+	if chDstSend != nil || chDstRecv != nil {
+		dstConnLock.Lock()
+		dstConn = nil
+		dstConnLock.Unlock()
+		fmt.Println("Resetting target port")
 	}
 
-	if errRcv != nil {
-		fmt.Printf("Error while reading data %s\n", errRcv)
+	if errSrcSnd != nil {
+		fmt.Printf("Error: Sending data failed with %s\n", errSrcSnd)
+	}
+	if errDstSnd != nil {
+		fmt.Printf("Error: Sending data failed with %s\n", errDstSnd)
+	}
+
+	if errSrcRcv != nil {
+		fmt.Printf("Error: Reading data failed with %s\n", errSrcRcv)
+	}
+	if errDstRcv != nil {
+		fmt.Printf("Error: Reading data failed with %s\n", errDstRcv)
 	}
 }
 
-func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
+func unixRelayer(src *net.UnixConn, dst *net.UnixConn, srcCh chan error, dstCh chan error) {
 	dataBuf := make([]byte, 4096)
 	oobBuf := make([]byte, 4096)
 
@@ -577,7 +649,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 				goto readAgain
 			}
 			fmt.Printf("Disconnected during read: %v\n", err)
-			ch <- true
+			srcCh <- err
+			dstCh <- nil
 			return
 		}
 
@@ -586,7 +659,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 			entries, err := syscall.ParseSocketControlMessage(oobBuf[:sOob])
 			if err != nil {
 				fmt.Printf("Failed to parse control message: %v\n", err)
-				ch <- true
+				srcCh <- nil
+				dstCh <- nil
 				return
 			}
 
@@ -594,7 +668,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 				fds, err = syscall.ParseUnixRights(&msg)
 				if err != nil {
 					fmt.Printf("Failed to get fd list for control message: %v\n", err)
-					ch <- true
+					srcCh <- nil
+					dstCh <- nil
 					return
 				}
 			}
@@ -609,13 +684,15 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 				goto writeAgain
 			}
 			fmt.Printf("Disconnected during write: %v\n", err)
-			ch <- true
+			srcCh <- nil
+			dstCh <- err
 			return
 		}
 
 		if sData != tData || sOob != tOob {
 			fmt.Printf("Some data got lost during transfer, disconnecting.")
-			ch <- true
+			srcCh <- nil
+			dstCh <- nil
 			return
 		}
 
@@ -625,7 +702,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 				err := syscall.Close(fd)
 				if err != nil {
 					fmt.Printf("Failed to close fd %d: %v\n", fd, err)
-					ch <- true
+					srcCh <- nil
+					dstCh <- nil
 					return
 				}
 			}
@@ -634,17 +712,41 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) {
 }
 
 func unixRelay(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
-	chSend := make(chan bool)
-	go unixRelayer(dst.(*net.UnixConn), src.(*net.UnixConn), chSend)
+	chSrcSend := make(chan error)
+	chDstSend := make(chan error)
+	go unixRelayer(dst.(*net.UnixConn), src.(*net.UnixConn), chSrcSend, chDstSend)
+
+	chSrcRecv := make(chan error)
+	chDstRecv := make(chan error)
+	go unixRelayer(src.(*net.UnixConn), dst.(*net.UnixConn), chSrcRecv, chDstRecv)
 
-	chRecv := make(chan bool)
-	go unixRelayer(src.(*net.UnixConn), dst.(*net.UnixConn), chRecv)
+	errSrcSnd := <-chSrcSend
+	errDstSnd := <-chDstSend
+	errSrcRcv := <-chSrcRecv
+	errDstRcv := <-chDstRecv
 
-	<-chSend
-	<-chRecv
+	if chDstSend != nil || chDstRecv != nil {
+		dstConnLock.Lock()
+		dstConn = nil
+		dstConnLock.Unlock()
+		fmt.Println("Resetting target port")
+	}
 
 	src.Close()
-	dst.Close()
+
+	if errSrcSnd != nil {
+		fmt.Printf("Error: Sending data failed with %s\n", errSrcSnd)
+	}
+	if errDstSnd != nil {
+		fmt.Printf("Error: Sending data failed with %s\n", errDstSnd)
+	}
+
+	if errSrcRcv != nil {
+		fmt.Printf("Error: Reading data failed with %s\n", errSrcRcv)
+	}
+	if errDstRcv != nil {
+		fmt.Printf("Error: Reading data failed with %s\n", errDstRcv)
+	}
 }
 
 func tryListen(protocol string, addr string) (net.Listener, error) {


More information about the lxc-devel mailing list