tailfs: replace webdavfs with reverse proxies

Instead of modeling remote WebDAV servers as actual
webdav.FS instances, we now just proxy traffic to them.
This not only simplifies the code, but it also allows
WebDAV locking to work correctly by making sure locks are
handled by the servers that need to handle them (i.e. the
ones actually serving the files).

Updates tailscale/corp#16827

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann
2024-02-21 06:40:12 -06:00
committed by Percy Wegmann
parent e1bd7488d0
commit 50fb8b9123
33 changed files with 1186 additions and 2008 deletions
@@ -0,0 +1,233 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package compositedav provides an http.Handler that composes multiple WebDAV
// services into a single WebDAV service that presents each of them as its own
// folder.
package compositedav
import (
"log"
"net/http"
"net/http/httputil"
"net/url"
"path"
"slices"
"strings"
"sync"
"github.com/tailscale/xnet/webdav"
"tailscale.com/tailfs/tailfsimpl/dirfs"
"tailscale.com/tailfs/tailfsimpl/shared"
"tailscale.com/tstime"
"tailscale.com/types/logger"
)
// Child is a child folder of this compositedav. Each Child is backed by a
// remote WebDAV service to which requests for the folder's contents are
// proxied.
type Child struct {
// The embedded dirfs.Child supplies the folder's Name, which is what
// SetChildren sorts by and findChildLocked searches on.
*dirfs.Child
// BaseURL is the base URL of the WebDAV service to which we'll proxy
// requests for this Child. We will append the filename from the original
// URL to this.
BaseURL string
// Transport (if specified) is the http transport to use when communicating
// with this Child's WebDAV service. CloseIdleConnections only has an effect
// when this is an *http.Transport.
Transport http.RoundTripper
// rp is the reverse proxy used to forward requests to BaseURL. It is built
// by init.
rp *httputil.ReverseProxy
// initOnce ensures that rp is built at most once.
initOnce sync.Once
}
// CloseIdleConnections closes any idle connections held by this Child's
// Transport. It does nothing unless Transport is an *http.Transport.
func (c *Child) CloseIdleConnections() {
	switch transport := c.Transport.(type) {
	case *http.Transport:
		transport.CloseIdleConnections()
	}
}
// init builds this Child's reverse proxy. It is safe to call repeatedly; the
// proxy is only constructed on the first call.
func (c *Child) init() {
	build := func() {
		// The rewrite hook is intentionally a no-op: delegate points the
		// request at the Child's URL before handing it to the proxy.
		// httputil.ReverseProxy still needs a Rewrite (or Director) to be set.
		noRewrite := func(*httputil.ProxyRequest) {}
		c.rp = &httputil.ReverseProxy{
			Rewrite:   noRewrite,
			Transport: c.Transport,
		}
	}
	c.initOnce.Do(build)
}
// Handler implements http.Handler by using a dirfs.FS for showing a virtual
// read-only folder that represents the Child WebDAV services as sub-folders
// and proxying all requests for resources on the children to those children
// via httputil.ReverseProxy instances.
type Handler struct {
// Logf specifies a logging function to use. If nil, the standard library's
// log.Printf is used instead.
Logf logger.Logf
// Clock, if specified, determines the current time. If not specified, we
// default to time.Now().
Clock tstime.Clock
// StatCache is an optional cache for PROPFIND results. A nil StatCache
// disables caching. It is invalidated on every request whose method is
// neither PROPFIND nor GET.
StatCache *StatCache
// childrenMu guards the fields below. Note that we do read the contents of
// children after releasing the read lock, which we can do because we never
// modify children but only ever replace it in SetChildren.
childrenMu sync.RWMutex
// children is kept sorted by Name (see SetChildren) so that
// findChildLocked can binary search it.
children []*Child
// staticRoot, if non-empty, adds one leading path component ahead of the
// children; maxPathLength accounts for it.
staticRoot string
}
// ServeHTTP implements http.Handler. PROPFIND requests get dedicated handling;
// every other request is either served from the local virtual folder or
// proxied to a Child, depending on how deep the requested path is.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case "PROPFIND":
		h.handlePROPFIND(w, r)
		return
	case "GET":
		// A plain read leaves the cached stats alone.
	default:
		// The user may be performing a modification (e.g. PUT, MKCOL, etc),
		// so invalidate the StatCache to make sure we're not knowingly
		// showing stale stats.
		// TODO(oxtoacart): maybe be more selective about invalidating cache
		h.StatCache.invalidate()
	}

	limit := h.maxPathLength(r)
	parts := shared.CleanAndSplit(r.URL.Path)
	if len(parts) < limit {
		// The path is too short to name anything inside a Child.
		h.handle(w, r)
		return
	}
	h.delegate(parts[limit-1:], w, r)
}
// handle serves the request from a dirfs.FS built from the current set of
// children, without contacting any Child.
func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
	// Snapshot the guarded fields; the children slice is never mutated in
	// place, so it is safe to read after the lock is released.
	h.childrenMu.RLock()
	clock := h.Clock
	snapshot := h.children
	staticRoot := h.staticRoot
	h.childrenMu.RUnlock()

	folders := make([]*dirfs.Child, len(snapshot))
	for i := range snapshot {
		folders[i] = snapshot[i].Child
	}

	locks := webdav.NewMemLS()
	fs := &dirfs.FS{
		Clock:      clock,
		Children:   folders,
		StaticRoot: staticRoot,
	}
	local := &webdav.Handler{
		LockSystem: locks,
		FileSystem: fs,
	}
	local.ServeHTTP(w, r)
}
// delegate sends the request to the Child WebDAV server.
//
// pathComponents[0] names the Child and the remaining components are appended
// to the Child's BaseURL to form the target URL. It responds with 404 if no
// such Child exists and with 500 if the Child's BaseURL cannot be parsed. The
// name of the Child is returned in all cases.
//
// The incoming request is not modified: a shallow copy is retargeted and
// handed to the reverse proxy, so callers can keep relying on r.URL and r.Host
// after delegate returns.
func (h *Handler) delegate(pathComponents []string, w http.ResponseWriter, r *http.Request) string {
	childName := pathComponents[0]
	child := h.GetChild(childName)
	if child == nil {
		w.WriteHeader(http.StatusNotFound)
		return childName
	}

	u, err := url.Parse(child.BaseURL)
	if err != nil {
		h.logf("warning: parse base URL %s failed: %s", child.BaseURL, err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return childName
	}
	u.Path = path.Join(u.Path, shared.Join(pathComponents[1:]...))

	// Handlers should not modify the request they were given, so point a
	// shallow copy at the Child instead of overwriting r.URL and r.Host.
	out := r.WithContext(r.Context())
	out.URL = u
	out.Host = u.Host
	child.rp.ServeHTTP(w, out)
	return childName
}
// SetChildren replaces the entire existing set of children with the given
// ones. If staticRoot is given, the children will appear within a subfolder
// named <staticRoot>.
//
// The given children are initialized and sorted by name in place. Once the new
// set is installed, idle connections belonging to the previous set of children
// are closed.
func (h *Handler) SetChildren(staticRoot string, children ...*Child) {
	for _, child := range children {
		child.init()
	}

	// findChildLocked relies on children being sorted by name.
	slices.SortFunc(children, func(a, b *Child) int {
		return strings.Compare(a.Name, b.Name)
	})

	h.childrenMu.Lock()
	// Remember the children being replaced (not the new ones) so that their
	// idle connections can be released below.
	oldChildren := h.children
	h.children = children
	h.staticRoot = staticRoot
	h.childrenMu.Unlock()

	for _, child := range oldChildren {
		child.CloseIdleConnections()
	}
}
// GetChild looks up the Child with the given name. It returns nil when there
// is no such Child.
func (h *Handler) GetChild(name string) *Child {
	h.childrenMu.RLock()
	defer h.childrenMu.RUnlock()

	if _, match := h.findChildLocked(name); match != nil {
		return match
	}
	return nil
}
// Close closes this Handler, including closing all idle connections on
// children and stopping the StatCache (if caching is enabled).
func (h *Handler) Close() {
	// children is being replaced, so the write lock is required; a read lock
	// would let this assignment race with concurrent readers.
	h.childrenMu.Lock()
	oldChildren := h.children
	h.children = nil
	h.childrenMu.Unlock()

	for _, child := range oldChildren {
		child.CloseIdleConnections()
	}

	if h.StatCache != nil {
		h.StatCache.stop()
	}
}
// findChildLocked binary searches h.children (which must be sorted by name)
// for the Child with the given name. It returns the search position together
// with the Child, or with nil if there is no match. The caller must hold
// childrenMu.
func (h *Handler) findChildLocked(name string) (int, *Child) {
	byName := func(candidate *Child, target string) int {
		return strings.Compare(candidate.Name, target)
	}
	idx, ok := slices.BinarySearchFunc(h.children, name, byName)
	if !ok {
		return idx, nil
	}
	return idx, h.children[idx]
}
// logf logs through h.Logf when one is configured and falls back to the
// standard library logger otherwise.
func (h *Handler) logf(format string, args ...any) {
	if h.Logf == nil {
		log.Printf(format, args...)
		return
	}
	h.Logf(format, args...)
}
// maxPathLength reports how many path components this handler can serve
// itself before a request has to be delegated to a Child: 2 when a staticRoot
// is configured, 1 otherwise. The request is currently not consulted.
func (h *Handler) maxPathLength(r *http.Request) int {
	h.childrenMu.RLock()
	hasStaticRoot := h.staticRoot != ""
	h.childrenMu.RUnlock()

	if hasStaticRoot {
		return 2
	}
	return 1
}
@@ -0,0 +1,84 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package compositedav
import (
"bytes"
"fmt"
"math"
"net/http"
"regexp"
"tailscale.com/tailfs/tailfsimpl/shared"
)
var (
// hrefRegex matches <D:href> elements in a PROPFIND response body. An
// optional leading slash is matched outside the capture group, so $1 holds
// the href without it; handlePROPFIND uses $1 to re-prefix the hrefs
// returned by a Child with the path under which that Child is mounted.
// NOTE(review): only the literal "D:" namespace prefix is matched, so this
// presumably relies on every Child server emitting that prefix; confirm.
hrefRegex = regexp.MustCompile(`(?s)<D:href>/?([^<]*)/?</D:href>`)
)
// handlePROPFIND serves PROPFIND requests.
//
// Requests that reach into a Child (taking the Depth header into account) are
// answered from the StatCache when possible and otherwise delegated to the
// Child, after which the hrefs in the response are rewritten so that they are
// relative to this handler rather than to the Child. Successful (207)
// responses are stored in the StatCache. All other requests are served from
// the local virtual folder.
func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) {
	// Capture the path before delegating so that the key used to store a
	// response is always the same one used to look it up, no matter what
	// delegation does to the request.
	reqPath := r.URL.Path
	pathComponents := shared.CleanAndSplit(reqPath)
	mpl := h.maxPathLength(r)
	depth := getDepth(r)

	// Equivalent to len(pathComponents)+depth > mpl, but written so that
	// "Depth: infinity" (math.MaxInt) cannot overflow and wrap negative. The
	// explicit length check also guarantees that the slicing below is in
	// range.
	toChild := !shared.IsRoot(reqPath) &&
		len(pathComponents) >= mpl &&
		depth > mpl-len(pathComponents)
	if !toChild {
		h.handle(w, r)
		return
	}

	if cached := h.StatCache.get(reqPath, depth); cached != nil {
		w.Header().Del("Content-Length")
		w.WriteHeader(http.StatusMultiStatus)
		w.Write(cached)
		return
	}

	// Use a buffering ResponseWriter so that we can manipulate the result.
	// The only thing we use from the original ResponseWriter is Header().
	bw := &bufferingResponseWriter{ResponseWriter: w}
	h.delegate(pathComponents[mpl-1:], bw, r)

	// Fixup paths to add the requested path as a prefix. Any '$' in the
	// prefix is doubled so that the replacement template treats it literally
	// instead of as a capture group reference.
	pathPrefix := shared.Join(pathComponents[0:mpl]...)
	escapedPrefix := bytes.ReplaceAll([]byte(pathPrefix), []byte("$"), []byte("$$"))
	b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("<D:href>%s/$1</D:href>", escapedPrefix)))

	if h.StatCache != nil && bw.status == http.StatusMultiStatus && b != nil {
		h.StatCache.set(reqPath, depth, b)
	}

	// The body length changed during the fixup, so drop the Child's
	// Content-Length before replaying the response.
	w.Header().Del("Content-Length")
	w.WriteHeader(bw.status)
	w.Write(b)
}
// getDepth translates the request's Depth header into a number: 1 for "1",
// math.MaxInt for "infinity", and 0 for "0", a missing header, or any other
// value.
func getDepth(r *http.Request) int {
	header := r.Header.Get("Depth")
	if header == "1" {
		return 1
	}
	if header == "infinity" {
		return math.MaxInt
	}
	return 0
}
// bufferingResponseWriter is an http.ResponseWriter that records the status
// code and buffers the body instead of sending them, so that the response can
// be rewritten before it is forwarded. Header() is served by the embedded
// ResponseWriter.
type bufferingResponseWriter struct {
	http.ResponseWriter
	// status is the most recently recorded status code, or 0 if nothing has
	// been written yet.
	status int
	// buf accumulates the response body.
	buf bytes.Buffer
}

// WriteHeader records statusCode without sending it. A later call replaces
// the status recorded by an earlier one.
func (bw *bufferingResponseWriter) WriteHeader(statusCode int) {
	bw.status = statusCode
}

// Write appends p to the buffered body. As with a regular
// http.ResponseWriter, writing a body without first calling WriteHeader
// implies a 200 OK status, so that the recorded status is always one that can
// be replayed.
func (bw *bufferingResponseWriter) Write(p []byte) (int, error) {
	if bw.status == 0 {
		bw.status = http.StatusOK
	}
	return bw.buf.Write(p)
}
@@ -0,0 +1,92 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package compositedav
import (
"sync"
"time"
"github.com/jellydator/ttlcache/v3"
)
// StatCache provides a cache for directory listings and file metadata.
// Especially when used from the command-line, mapped WebDAV drives can
// generate repetitive requests for the same file metadata. This cache helps
// reduce the number of round-trips to the WebDAV server for such requests.
// This is similar to the DirectoryCacheLifetime setting of Windows' built-in
// SMB client, see
// https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10)
//
// The get, set and invalidate methods may be called on a nil *StatCache, in
// which case they do nothing.
type StatCache struct {
// TTL is the time-to-live given to each per-depth cache when it is created.
TTL time.Duration
// mu guards the below values.
mu sync.Mutex
// cachesByDepthAndPath holds one TTL cache per depth, each keyed by name.
// Both the map and its entries are created lazily by set.
cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte]
}
// get returns the value cached for the given name and depth, or nil if there
// is none. It is safe to call on a nil *StatCache.
func (c *StatCache) get(name string, depth int) []byte {
	if c == nil {
		return nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Indexing a nil map simply reports a miss, so the map itself needs no
	// separate nil check.
	perDepth, ok := c.cachesByDepthAndPath[depth]
	if !ok || perDepth == nil {
		return nil
	}
	if entry := perDepth.Get(name); entry != nil {
		return entry.Value()
	}
	return nil
}
// set stores value under the given name and depth, creating and starting the
// cache for that depth on first use. It is safe to call on a nil *StatCache,
// in which case nothing is stored.
func (c *StatCache) set(name string, depth int, value []byte) {
	if c == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.cachesByDepthAndPath == nil {
		c.cachesByDepthAndPath = map[int]*ttlcache.Cache[string, []byte]{}
	}

	perDepth, ok := c.cachesByDepthAndPath[depth]
	if !ok || perDepth == nil {
		ttlOption := ttlcache.WithTTL[string, []byte](c.TTL)
		perDepth = ttlcache.New(ttlOption)
		// Start runs until the cache is stopped, so give it its own goroutine.
		go perDepth.Start()
		c.cachesByDepthAndPath[depth] = perDepth
	}
	perDepth.Set(name, value, ttlcache.DefaultTTL)
}
// invalidate removes every entry from every per-depth cache. It is safe to
// call on a nil *StatCache.
func (c *StatCache) invalidate() {
	if c == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	for depth := range c.cachesByDepthAndPath {
		c.cachesByDepthAndPath[depth].DeleteAll()
	}
}
// stop stops every per-depth cache and forgets them. Like the other StatCache
// methods, it is safe to call on a nil *StatCache.
func (c *StatCache) stop() {
	if c == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	for _, cache := range c.cachesByDepthAndPath {
		cache.Stop()
	}
	// Forget the stopped caches so that a repeated stop has nothing left to
	// stop, and so that a later set starts a fresh cache instead of reusing a
	// stopped one.
	// NOTE(review): ttlcache's Stop presumably blocks until the cache's Start
	// loop acknowledges it, in which case stopping the same cache twice would
	// hang; confirm against the vendored ttlcache version.
	c.cachesByDepthAndPath = nil
}
@@ -0,0 +1,75 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package compositedav
import (
"bytes"
"testing"
"time"
"tailscale.com/tstest"
)
// Fixtures shared by the StatCache tests below.
var (
// val is the value stored in the cache.
val = []byte("1")
// file is the name under which val is stored.
file = "file"
)
// TestStatCacheNoTimeout checks that a value can be stored and then read back
// repeatedly while it is still within its TTL.
func TestStatCacheNoTimeout(t *testing.T) {
	// Make sure we don't leak goroutines
	tstest.ResourceCheck(t)

	cache := &StatCache{TTL: 5 * time.Second}
	defer cache.stop()

	// Nothing has been stored yet, so the lookup must miss.
	if got := cache.get(file, 1); got != nil {
		t.Errorf("got %q, want nil", got)
	}

	// Once stored, the value must be returned on the first lookup and still
	// be cached on the second.
	cache.set(file, 1, val)
	for attempt := 0; attempt < 2; attempt++ {
		if got := cache.get(file, 1); !bytes.Equal(got, val) {
			t.Errorf("got %q, want %q", got, val)
		}
	}
}
// TestStatCacheTimeout checks that cached values disappear once their TTL has
// elapsed, and that invalidate removes values that have not yet expired.
func TestStatCacheTimeout(t *testing.T) {
	// Make sure we don't leak goroutines
	tstest.ResourceCheck(t)

	c := &StatCache{TTL: 250 * time.Millisecond}
	defer c.stop()

	// set new stat
	c.set(file, 1, val)
	fetched := c.get(file, 1)
	if !bytes.Equal(fetched, val) {
		t.Errorf("got %q, want %q", fetched, val)
	}

	// wait for cache to expire and refetch stat, should be empty now
	time.Sleep(c.TTL * 2)

	fetched = c.get(file, 1)
	if fetched != nil {
		t.Errorf("cached value should have expired after TTL")
	}

	c.set(file, 1, val)
	// invalidate the cache and make sure nothing is returned
	c.invalidate()
	fetched = c.get(file, 1)
	if fetched != nil {
		t.Errorf("invalidate should have cleared cached value")
	}
}