Fixes #4395 Fixes #2605 Signed-off-by: James Tucker <james@tailscale.com>main
parent
90b5f6286c
commit
700bd37730
@ -0,0 +1,26 @@ |
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package tshttpproxy |
||||
|
||||
import ( |
||||
"net/http" |
||||
"net/url" |
||||
|
||||
"tailscale.com/version/distro" |
||||
) |
||||
|
||||
func init() { |
||||
sysProxyFromEnv = linuxSysProxyFromEnv |
||||
} |
||||
|
||||
func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { |
||||
if distro.Get() == distro.Synology { |
||||
return synologyProxyFromConfigCached(req) |
||||
} |
||||
return nil, nil |
||||
} |
||||
@ -0,0 +1,132 @@ |
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package tshttpproxy |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
"os" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"tailscale.com/util/lineread" |
||||
) |
||||
|
||||
// These vars are overridden for tests.
|
||||
var ( |
||||
synologyProxyConfigPath = "/etc/proxy.conf" |
||||
|
||||
openSynologyProxyConf = func() (io.ReadCloser, error) { |
||||
return os.Open(synologyProxyConfigPath) |
||||
} |
||||
) |
||||
|
||||
var cache struct { |
||||
sync.Mutex |
||||
proxy *url.URL |
||||
updated time.Time |
||||
} |
||||
|
||||
func synologyProxyFromConfigCached(req *http.Request) (*url.URL, error) { |
||||
if req.URL == nil { |
||||
return nil, nil |
||||
} |
||||
|
||||
cache.Lock() |
||||
defer cache.Unlock() |
||||
|
||||
modtime := mtime(synologyProxyConfigPath) |
||||
|
||||
if cache.updated == modtime { |
||||
return cache.proxy, nil |
||||
} |
||||
|
||||
val, err := synologyProxyFromConfig(req) |
||||
cache.proxy = val |
||||
|
||||
cache.updated = modtime |
||||
|
||||
return val, err |
||||
} |
||||
|
||||
func synologyProxyFromConfig(req *http.Request) (*url.URL, error) { |
||||
r, err := openSynologyProxyConf() |
||||
if err != nil { |
||||
if os.IsNotExist(err) { |
||||
return nil, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
defer r.Close() |
||||
|
||||
return parseSynologyConfig(r) |
||||
} |
||||
|
||||
func parseSynologyConfig(r io.Reader) (*url.URL, error) { |
||||
cfg := map[string]string{} |
||||
|
||||
if err := lineread.Reader(r, func(line []byte) error { |
||||
// accept and skip over empty lines
|
||||
line = bytes.TrimSpace(line) |
||||
if len(line) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
key, value, ok := strings.Cut(string(line), "=") |
||||
if !ok { |
||||
return fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) |
||||
} |
||||
cfg[string(key)] = string(value) |
||||
return nil |
||||
}); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if cfg["proxy_enabled"] != "yes" { |
||||
return nil, nil |
||||
} |
||||
|
||||
proxyURL := new(url.URL) |
||||
if cfg["auth_enabled"] == "yes" { |
||||
proxyURL.User = url.UserPassword(cfg["proxy_user"], cfg["proxy_pwd"]) |
||||
} |
||||
|
||||
proxyURL.Scheme = "https" |
||||
host, port := cfg["https_host"], cfg["https_port"] |
||||
if host == "" { |
||||
proxyURL.Scheme = "http" |
||||
host, port = cfg["http_host"], cfg["http_port"] |
||||
} |
||||
|
||||
if host == "" { |
||||
return nil, nil |
||||
} |
||||
|
||||
if port != "" { |
||||
proxyURL.Host = net.JoinHostPort(host, port) |
||||
} else { |
||||
proxyURL.Host = host |
||||
} |
||||
|
||||
return proxyURL, nil |
||||
} |
||||
|
||||
// mtime stat's path and returns it's modification time. If path does not exist,
|
||||
// it returns the unix epoch.
|
||||
func mtime(path string) time.Time { |
||||
fi, err := os.Stat(path) |
||||
if err != nil { |
||||
return time.Unix(0, 0) |
||||
} |
||||
return fi.ModTime() |
||||
} |
||||
@ -0,0 +1,288 @@ |
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package tshttpproxy |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"net/url" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestSynologyProxyFromConfigCached(t *testing.T) { |
||||
req, err := http.NewRequest("GET", "https://example.org/", nil) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
var orig string |
||||
orig, synologyProxyConfigPath = synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf") |
||||
defer func() { synologyProxyConfigPath = orig }() |
||||
|
||||
t.Run("no config file", func(t *testing.T) { |
||||
if _, err := os.Stat(synologyProxyConfigPath); err == nil { |
||||
t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) |
||||
} |
||||
|
||||
cache.updated = time.Time{} |
||||
cache.proxy = nil |
||||
|
||||
if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { |
||||
t.Fatalf("got %s, %v; want nil, nil", val, err) |
||||
} |
||||
|
||||
if got, want := cache.updated, time.Unix(0, 0); got != want { |
||||
t.Fatalf("got %s, want %s", got, want) |
||||
} |
||||
if cache.proxy != nil { |
||||
t.Fatalf("got %s, want nil", cache.proxy) |
||||
} |
||||
}) |
||||
|
||||
t.Run("config file updated", func(t *testing.T) { |
||||
cache.updated = time.Now() |
||||
cache.proxy = nil |
||||
|
||||
if err := ioutil.WriteFile(synologyProxyConfigPath, []byte(` |
||||
proxy_enabled=yes |
||||
http_host=10.0.0.55 |
||||
http_port=80 |
||||
`), 0600); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
val, err := synologyProxyFromConfigCached(req) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { |
||||
t.Fatalf("got %s; want %s", val, want) |
||||
} |
||||
}) |
||||
|
||||
t.Run("config file removed", func(t *testing.T) { |
||||
cache.updated = time.Now() |
||||
cache.proxy = urlMustParse("http://127.0.0.1/") |
||||
|
||||
if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
val, err := synologyProxyFromConfigCached(req) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if val != nil { |
||||
t.Fatalf("got %s; want nil", val) |
||||
} |
||||
if cache.proxy != nil { |
||||
t.Fatalf("got %s, want nil", cache.proxy) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func TestSynologyProxyFromConfig(t *testing.T) { |
||||
var ( |
||||
openReader io.ReadCloser |
||||
openErr error |
||||
) |
||||
var origOpen func() (io.ReadCloser, error) |
||||
origOpen, openSynologyProxyConf = openSynologyProxyConf, func() (io.ReadCloser, error) { |
||||
return openReader, openErr |
||||
} |
||||
defer func() { openSynologyProxyConf = origOpen }() |
||||
|
||||
req, err := http.NewRequest("GET", "https://example.com/", nil) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
t.Run("with config", func(t *testing.T) { |
||||
mc := &mustCloser{Reader: strings.NewReader(` |
||||
proxy_user=foo |
||||
proxy_pwd=bar |
||||
proxy_enabled=yes |
||||
adv_enabled=yes |
||||
bypass_enabled=yes |
||||
auth_enabled=yes |
||||
https_host=10.0.0.66 |
||||
https_port=8443 |
||||
http_host=10.0.0.55 |
||||
http_port=80 |
||||
`)} |
||||
defer mc.check(t) |
||||
openReader = mc |
||||
|
||||
proxyURL, err := synologyProxyFromConfig(req) |
||||
|
||||
if got, want := err, openErr; got != want { |
||||
t.Fatalf("got %s, want %s", got, want) |
||||
} |
||||
|
||||
if got, want := proxyURL, urlMustParse("https://foo:bar@10.0.0.66:8443"); got.String() != want.String() { |
||||
t.Fatalf("got %s, want %s", got, want) |
||||
} |
||||
|
||||
}) |
||||
|
||||
t.Run("non-existent config", func(t *testing.T) { |
||||
openReader = nil |
||||
openErr = os.ErrNotExist |
||||
|
||||
proxyURL, err := synologyProxyFromConfig(req) |
||||
if err != nil { |
||||
t.Fatalf("expected no error, got %s", err) |
||||
} |
||||
if proxyURL != nil { |
||||
t.Fatalf("expected no url, got %s", proxyURL) |
||||
} |
||||
}) |
||||
|
||||
t.Run("error opening config", func(t *testing.T) { |
||||
openReader = nil |
||||
openErr = errors.New("example error") |
||||
|
||||
proxyURL, err := synologyProxyFromConfig(req) |
||||
if err != openErr { |
||||
t.Fatalf("expected %s, got %s", openErr, err) |
||||
} |
||||
if proxyURL != nil { |
||||
t.Fatalf("expected no url, got %s", proxyURL) |
||||
} |
||||
}) |
||||
|
||||
} |
||||
|
||||
func TestParseSynologyConfig(t *testing.T) { |
||||
cases := map[string]struct { |
||||
input string |
||||
url *url.URL |
||||
err error |
||||
}{ |
||||
"populated": { |
||||
input: ` |
||||
proxy_user=foo |
||||
proxy_pwd=bar |
||||
proxy_enabled=yes |
||||
adv_enabled=yes |
||||
bypass_enabled=yes |
||||
auth_enabled=yes |
||||
https_host=10.0.0.66 |
||||
https_port=8443 |
||||
http_host=10.0.0.55 |
||||
http_port=80 |
||||
`, |
||||
url: urlMustParse("https://foo:bar@10.0.0.66:8443"), |
||||
err: nil, |
||||
}, |
||||
"no-auth": { |
||||
input: ` |
||||
proxy_user=foo |
||||
proxy_pwd=bar |
||||
proxy_enabled=yes |
||||
adv_enabled=yes |
||||
bypass_enabled=yes |
||||
auth_enabled=no |
||||
https_host=10.0.0.66 |
||||
https_port=8443 |
||||
http_host=10.0.0.55 |
||||
http_port=80 |
||||
`, |
||||
url: urlMustParse("https://10.0.0.66:8443"), |
||||
err: nil, |
||||
}, |
||||
"http": { |
||||
input: ` |
||||
proxy_user=foo |
||||
proxy_pwd=bar |
||||
proxy_enabled=yes |
||||
adv_enabled=yes |
||||
bypass_enabled=yes |
||||
auth_enabled=yes |
||||
https_host= |
||||
https_port=8443 |
||||
http_host=10.0.0.55 |
||||
http_port=80 |
||||
`, |
||||
url: urlMustParse("http://foo:bar@10.0.0.55:80"), |
||||
err: nil, |
||||
}, |
||||
"empty": { |
||||
input: ` |
||||
proxy_user= |
||||
proxy_pwd= |
||||
proxy_enabled= |
||||
adv_enabled= |
||||
bypass_enabled= |
||||
auth_enabled= |
||||
https_host= |
||||
https_port= |
||||
http_host= |
||||
http_port= |
||||
`, |
||||
url: nil, |
||||
err: nil, |
||||
}, |
||||
} |
||||
|
||||
for name, example := range cases { |
||||
t.Run(name, func(t *testing.T) { |
||||
url, err := parseSynologyConfig(strings.NewReader(example.input)) |
||||
if err != example.err { |
||||
t.Fatal(err) |
||||
} |
||||
if example.err != nil { |
||||
return |
||||
} |
||||
|
||||
if url == nil && example.url == nil { |
||||
return |
||||
} |
||||
|
||||
if example.url == nil { |
||||
if url != nil { |
||||
t.Fatalf("got %s, want nil", url) |
||||
} |
||||
} |
||||
|
||||
if got, want := example.url.String(), url.String(); got != want { |
||||
t.Fatalf("got %s, want %s", got, want) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
func urlMustParse(u string) *url.URL { |
||||
r, err := url.Parse(u) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("urlMustParse: %s", err)) |
||||
} |
||||
return r |
||||
} |
||||
|
||||
type mustCloser struct { |
||||
io.Reader |
||||
closed bool |
||||
} |
||||
|
||||
func (m *mustCloser) Close() error { |
||||
m.closed = true |
||||
return nil |
||||
} |
||||
|
||||
func (m *mustCloser) check(t *testing.T) { |
||||
if !m.closed { |
||||
t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue