191 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			191 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
| // Copyright 2012 Google Inc. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package main
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/http/httputil"
 | |
| 	"os"
 | |
| 	"os/exec"
 | |
| 	"os/signal"
 | |
| 	"sync"
 | |
| 	"syscall"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// Proxy serves two loopback reverse proxies (plain HTTP and an https-rewriting
// variant) and advertises their addresses to clients over a unix socket.
type Proxy struct {
	// BuildLabel identifies the running build; closeOnUpdate compares it with
	// the label printed by `os.Args[0] --print_label`.
	BuildLabel string
	// MaxIdleDuration is how long the unix socket may go without an active
	// client before the proxy shuts down; a non-positive value disables this.
	MaxIdleDuration time.Duration
	// PollUpdateInterval is the delay between checks for an updated binary.
	PollUpdateInterval time.Duration

	ul                  net.Listener // unix socket clients query for the proxy addresses
	httpAddr, httpsAddr string       // loopback host:port of the HTTP and HTTPS proxies
}
 | |
| 
 | |
| func (p *Proxy) Run() error {
 | |
| 	hl, err := net.Listen("tcp", "127.0.0.1:0")
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("http listen failed: %v", err)
 | |
| 	}
 | |
| 	defer hl.Close()
 | |
| 
 | |
| 	hsl, err := net.Listen("tcp", "127.0.0.1:0")
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("https listen failed: %v", err)
 | |
| 	}
 | |
| 	defer hsl.Close()
 | |
| 
 | |
| 	p.ul, err = DefaultSocket.Listen()
 | |
| 	if err != nil {
 | |
| 		c, derr := DefaultSocket.Dial()
 | |
| 		if derr == nil {
 | |
| 			c.Close()
 | |
| 			fmt.Println("OK\nA proxy is already running... exiting")
 | |
| 			return nil
 | |
| 		} else if e, ok := derr.(*net.OpError); ok && e.Err == syscall.ECONNREFUSED {
 | |
| 			// Nothing is listening on the socket, unlink it and try again.
 | |
| 			syscall.Unlink(DefaultSocket.Path())
 | |
| 			p.ul, err = DefaultSocket.Listen()
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("unix listen failed on %v: %v", DefaultSocket.Path(), err)
 | |
| 		}
 | |
| 	}
 | |
| 	defer p.ul.Close()
 | |
| 	go p.closeOnSignal()
 | |
| 	go p.closeOnUpdate()
 | |
| 
 | |
| 	p.httpAddr = hl.Addr().String()
 | |
| 	p.httpsAddr = hsl.Addr().String()
 | |
| 	fmt.Printf("OK\nListening on unix socket=%v http=%v https=%v\n",
 | |
| 		p.ul.Addr(), p.httpAddr, p.httpsAddr)
 | |
| 
 | |
| 	result := make(chan error, 2)
 | |
| 	go p.serveUnix(result)
 | |
| 	go func() {
 | |
| 		result <- http.Serve(hl, &httputil.ReverseProxy{
 | |
| 			FlushInterval: 500 * time.Millisecond,
 | |
| 			Director:      func(r *http.Request) {},
 | |
| 		})
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		result <- http.Serve(hsl, &httputil.ReverseProxy{
 | |
| 			FlushInterval: 500 * time.Millisecond,
 | |
| 			Director: func(r *http.Request) {
 | |
| 				r.URL.Scheme = "https"
 | |
| 			},
 | |
| 		})
 | |
| 	}()
 | |
| 	return <-result
 | |
| }
 | |
| 
 | |
// socketContext tracks the unix-socket connections currently being served and
// the time the most recent one finished.
//
// It used to embed a sync.WaitGroup. However, serveUnix calls Add(1) (possibly
// from a zero count) while closeOnIdle may be blocked in Wait. A WaitGroup
// forbids that pattern and can panic with "WaitGroup is reused before previous
// Wait has returned".
//
// Here every transition happens under mutex, so Add, Done and Wait may be
// called concurrently in any order. The zero value is ready to use.
type socketContext struct {
	mutex  sync.Mutex
	last   time.Time  // guarded by mutex; when the most recent connection finished
	active int        // guarded by mutex; connections added but not yet Done
	idle   *sync.Cond // guarded by mutex; created lazily, broadcast when active hits 0
}

// Add adjusts the number of active connections by delta.
// Like sync.WaitGroup.Add, it panics if the count would become negative.
func (sc *socketContext) Add(delta int) {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	sc.addLocked(delta)
}

// Done marks one connection as finished and records the finish time, which
// closeOnIdle reads to decide whether the proxy has been idle long enough.
func (sc *socketContext) Done() {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	sc.last = time.Now()
	sc.addLocked(-1)
}

// Wait blocks until no connections are active. It returns immediately when
// the count is already zero.
func (sc *socketContext) Wait() {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	for sc.active > 0 {
		sc.cond().Wait()
	}
}

// addLocked applies delta to the active count and wakes waiters when it drops
// to zero. sc.mutex must be held.
func (sc *socketContext) addLocked(delta int) {
	sc.active += delta
	if sc.active < 0 {
		panic("socketContext: negative active connection count")
	}
	if sc.active == 0 && sc.idle != nil {
		sc.idle.Broadcast()
	}
}

// cond returns the idle condition variable, creating it on first use so the
// zero socketContext stays valid. sc.mutex must be held.
func (sc *socketContext) cond() *sync.Cond {
	if sc.idle == nil {
		sc.idle = sync.NewCond(&sc.mutex)
	}
	return sc.idle
}
 | |
| 
 | |
| func (p *Proxy) serveUnix(result chan<- error) {
 | |
| 	sockCtx := &socketContext{}
 | |
| 	go p.closeOnIdle(sockCtx)
 | |
| 
 | |
| 	var err error
 | |
| 	for {
 | |
| 		var uconn net.Conn
 | |
| 		uconn, err = p.ul.Accept()
 | |
| 		if err != nil {
 | |
| 			err = fmt.Errorf("accept failed: %v", err)
 | |
| 			break
 | |
| 		}
 | |
| 		sockCtx.Add(1)
 | |
| 		go p.handleUnixConn(sockCtx, uconn)
 | |
| 	}
 | |
| 	sockCtx.Wait()
 | |
| 	result <- err
 | |
| }
 | |
| 
 | |
| func (p *Proxy) handleUnixConn(sockCtx *socketContext, uconn net.Conn) {
 | |
| 	defer sockCtx.Done()
 | |
| 	defer uconn.Close()
 | |
| 	data := []byte(fmt.Sprintf("%v\n%v", p.httpsAddr, p.httpAddr))
 | |
| 	uconn.SetDeadline(time.Now().Add(5 * time.Second))
 | |
| 	for i := 0; i < 2; i++ {
 | |
| 		if n, err := uconn.Write(data); err != nil {
 | |
| 			log.Printf("error sending http addresses: %+v\n", err)
 | |
| 			return
 | |
| 		} else if n != len(data) {
 | |
| 			log.Printf("sent %d data bytes, wanted %d\n", n, len(data))
 | |
| 			return
 | |
| 		}
 | |
| 		if _, err := uconn.Read([]byte{0, 0, 0, 0}); err != nil {
 | |
| 			log.Printf("error waiting for Ack: %+v\n", err)
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 	// Wait without a deadline for the client to finish via EOF
 | |
| 	uconn.SetDeadline(time.Time{})
 | |
| 	uconn.Read([]byte{0, 0, 0, 0})
 | |
| }
 | |
| 
 | |
| func (p *Proxy) closeOnIdle(sockCtx *socketContext) {
 | |
| 	for d := p.MaxIdleDuration; d > 0; {
 | |
| 		time.Sleep(d)
 | |
| 		sockCtx.Wait()
 | |
| 		sockCtx.mutex.Lock()
 | |
| 		if d = sockCtx.last.Add(p.MaxIdleDuration).Sub(time.Now()); d <= 0 {
 | |
| 			log.Println("graceful shutdown from idle timeout")
 | |
| 			p.ul.Close()
 | |
| 		}
 | |
| 		sockCtx.mutex.Unlock()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (p *Proxy) closeOnUpdate() {
 | |
| 	for {
 | |
| 		time.Sleep(p.PollUpdateInterval)
 | |
| 		if out, err := exec.Command(os.Args[0], "--print_label").Output(); err != nil {
 | |
| 			log.Printf("error polling for updated binary: %v\n", err)
 | |
| 		} else if s := string(out[:len(out)-1]); p.BuildLabel != s {
 | |
| 			log.Printf("graceful shutdown from updated binary: %q --> %q\n", p.BuildLabel, s)
 | |
| 			p.ul.Close()
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (p *Proxy) closeOnSignal() {
 | |
| 	ch := make(chan os.Signal, 10)
 | |
| 	signal.Notify(ch, os.Interrupt, os.Kill, os.Signal(syscall.SIGTERM), os.Signal(syscall.SIGHUP))
 | |
| 	sig := <-ch
 | |
| 	p.ul.Close()
 | |
| 	switch sig {
 | |
| 	case os.Signal(syscall.SIGHUP):
 | |
| 		log.Printf("graceful shutdown from signal: %v\n", sig)
 | |
| 	default:
 | |
| 		log.Fatalf("exiting from signal: %v\n", sig)
 | |
| 	}
 | |
| }
 |