TOC

Golang 实现简单反向代理

看到微信公众号 Go语言教程 的文章《golang 实现简单网关》才知道还有 httputil.ReverseProxy 这么个东西。
PS:这玩意儿有什么必要放在标准库?

还挺有趣,于是在作者示例的基础之上完善了一下,实现了一个支持多服务、多后端节点的随机负载均衡(后续还可以再补上权重)。

package main

import (
    "bytes"
    "fmt"
    "io/ioutil"
    "log"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
    "net/url"
    "strings"
    "time"
)

// main starts the demo reverse proxy: it registers one service ("service1")
// with two local backend nodes and serves the proxy on 127.0.0.1:2002.
// The process exits through log.Fatal when the HTTP server stops.
func main() {
    const listenAddr = "127.0.0.1:2002"

    proxy := NewReverseProxy(map[string][]string{
        "service1": {"http://127.0.0.1:2003", "http://127.0.0.1:2004"},
    })

    log.Println("Starting httpserver at " + listenAddr)
    err := http.ListenAndServe(listenAddr, proxy)
    log.Fatal(err)
}

// reqConvert rewrites req in place so that it is addressed to target.
//
// It copies target's scheme and host onto req.URL, joins target's path and the
// request path with exactly one slash between them, merges the two raw query
// strings (target's first), and sets the X-Real-Ip header to the host part of
// req.RemoteAddr.
func reqConvert(req *http.Request, target *url.URL) {
    req.URL.Scheme = target.Scheme
    req.URL.Host = target.Host
    req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
    if target.RawQuery == "" || req.URL.RawQuery == "" {
        // At most one side is non-empty, so no "&" separator is needed.
        req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
    } else {
        req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
    }
    if _, ok := req.Header["User-Agent"]; !ok {
        // Store an explicit empty value when the client sent no User-Agent,
        // so the header is present (but blank) on the outgoing request.
        req.Header.Set("User-Agent", "")
    }

    // RemoteAddr is "host:port"; IPv6 hosts are bracketed ("[::1]:1234"), so
    // splitting on the first ":" would produce "[". SplitHostPort handles both
    // forms. If the address has no port, fall back to the raw value.
    clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
    if err != nil {
        clientIP = req.RemoteAddr
    }
    req.Header.Set("X-Real-Ip", clientIP)
}

// NewReverseProxy builds an httputil.ReverseProxy that routes each request to
// a backend chosen by the first segment of the request path.
//
// backends maps a service name to the base URLs of its nodes. A request for
// "/<service>/<rest>" is sent to a randomly selected node of <service>, with
// the path rewritten to <rest> joined onto the node URL's own path (see
// reqConvert). Node URLs that cannot be parsed are logged and skipped.
//
// Requests whose path does not have the "/<service>/<rest>" shape, or whose
// service has no usable node, are not routed: their URL scheme and host are
// cleared so they cannot be forwarded anywhere.
//
// Every upstream response whose status is not 200 OK has its body prefixed
// with "hello " and its Content-Length adjusted accordingly.
func NewReverseProxy(backends map[string][]string) *httputil.ReverseProxy {
    targets := make(map[string][]*url.URL, len(backends))
    for srv, nodes := range backends {
        for _, nodeURL := range nodes {
            target, err := url.Parse(nodeURL)
            if err != nil {
                // A nil target would panic later inside the director.
                log.Printf("skipping invalid backend %q of service %q: %v", nodeURL, srv, err)
                continue
            }
            targets[srv] = append(targets[srv], target)
        }
    }

    // Seed the shared generator once here instead of on every request.
    rand.Seed(time.Now().UnixNano())

    // reject makes sure an unroutable request has no destination. Without
    // this, a request sent with an absolute-form URL would keep its
    // client-chosen scheme and host and be forwarded there.
    reject := func(req *http.Request) {
        req.URL.Scheme = ""
        req.URL.Host = ""
    }

    director := func(req *http.Request) {
        // "/svc/rest" splits into ["", "svc", "rest"].
        segments := strings.SplitN(req.URL.Path, "/", 3)
        if len(segments) != 3 {
            reject(req)
            return
        }
        srv := segments[1]
        nodes := targets[srv]
        if len(nodes) == 0 {
            // Log before the path is rewritten so the full original path shows.
            log.Printf("unknown service %q for path %q", srv, req.URL.Path)
            reject(req)
            return
        }
        req.URL.Path = segments[2]
        reqConvert(req, nodes[rand.Intn(len(nodes))])
    }

    modifyFunc := func(res *http.Response) error {
        if res.StatusCode == http.StatusOK {
            return nil
        }
        oldPayload, err := ioutil.ReadAll(res.Body)
        if err != nil {
            return err
        }
        // res.Body is replaced below, so this is the last reference to the
        // upstream body; close it here. It has been read to EOF, so the close
        // error is deliberately ignored.
        _ = res.Body.Close()

        newPayload := append([]byte("hello "), oldPayload...)
        res.Body = ioutil.NopCloser(bytes.NewReader(newPayload))
        res.ContentLength = int64(len(newPayload))
        res.Header.Set("Content-Length", fmt.Sprint(len(newPayload)))
        return nil
    }

    return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyFunc}
}

// singleJoiningSlash concatenates a and b so that exactly one "/" sits at the
// seam: a duplicate slash is dropped when both sides supply one, and a slash
// is inserted when neither does.
func singleJoiningSlash(a, b string) string {
    endsWithSlash := strings.HasSuffix(a, "/")
    startsWithSlash := strings.HasPrefix(b, "/")

    if endsWithSlash != startsWithSlash {
        // Exactly one side already provides the separator.
        return a + b
    }
    if endsWithSlash {
        // Both sides provide it; drop b's leading slash.
        return a + b[1:]
    }
    // Neither side provides it.
    return a + "/" + b
}