Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>main
parent
19c3e6cc9e
commit
5e9e11a77d
@ -0,0 +1,545 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package testcontrol contains a minimal control plane server for testing purposes.
|
||||
package testcontrol |
||||
|
||||
import ( |
||||
"bytes" |
||||
crand "crypto/rand" |
||||
"encoding/binary" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"log" |
||||
"math/rand" |
||||
"net/http" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/klauspost/compress/zstd" |
||||
"golang.org/x/crypto/nacl/box" |
||||
"inet.af/netaddr" |
||||
"tailscale.com/derp/derpmap" |
||||
"tailscale.com/smallzstd" |
||||
"tailscale.com/tailcfg" |
||||
"tailscale.com/types/logger" |
||||
"tailscale.com/types/wgkey" |
||||
) |
||||
|
||||
// Server is a control plane server. Its zero value is ready for use.
|
||||
// Everything is stored in-memory in one tailnet.
|
||||
type Server struct { |
||||
Logf logger.Logf // nil means to use the log package
|
||||
DERPMap *tailcfg.DERPMap // nil means to use prod DERP map
|
||||
|
||||
initMuxOnce sync.Once |
||||
mux *http.ServeMux |
||||
|
||||
mu sync.Mutex |
||||
pubKey wgkey.Key |
||||
privKey wgkey.Private |
||||
nodes map[tailcfg.NodeKey]*tailcfg.Node |
||||
users map[tailcfg.NodeKey]*tailcfg.User |
||||
logins map[tailcfg.NodeKey]*tailcfg.Login |
||||
updates map[tailcfg.NodeID]chan updateType |
||||
} |
||||
|
||||
func (s *Server) logf(format string, a ...interface{}) { |
||||
if s.Logf != nil { |
||||
s.Logf(format, a...) |
||||
} else { |
||||
log.Printf(format, a...) |
||||
} |
||||
} |
||||
|
||||
func (s *Server) initMux() { |
||||
s.mux = http.NewServeMux() |
||||
s.mux.HandleFunc("/", s.serveUnhandled) |
||||
s.mux.HandleFunc("/key", s.serveKey) |
||||
s.mux.HandleFunc("/machine/", s.serveMachine) |
||||
} |
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
||||
s.initMuxOnce.Do(s.initMux) |
||||
s.mux.ServeHTTP(w, r) |
||||
} |
||||
|
||||
func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) { |
||||
var got bytes.Buffer |
||||
r.Write(&got) |
||||
go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes())) |
||||
} |
||||
|
||||
func (s *Server) publicKey() wgkey.Key { |
||||
pub, _ := s.keyPair() |
||||
return pub |
||||
} |
||||
|
||||
func (s *Server) privateKey() wgkey.Private { |
||||
_, priv := s.keyPair() |
||||
return priv |
||||
} |
||||
|
||||
func (s *Server) keyPair() (pub wgkey.Key, priv wgkey.Private) { |
||||
s.mu.Lock() |
||||
defer s.mu.Unlock() |
||||
if s.pubKey.IsZero() { |
||||
var err error |
||||
s.privKey, err = wgkey.NewPrivate() |
||||
if err != nil { |
||||
go panic(err) // bring down test, even if in http.Handler
|
||||
} |
||||
s.pubKey = s.privKey.Public() |
||||
} |
||||
return s.pubKey, s.privKey |
||||
} |
||||
|
||||
func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) { |
||||
w.Header().Set("Content-Type", "text/plain") |
||||
w.WriteHeader(200) |
||||
io.WriteString(w, s.publicKey().HexString()) |
||||
} |
||||
|
||||
func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) { |
||||
mkeyStr := strings.TrimPrefix(r.URL.Path, "/machine/") |
||||
rem := "" |
||||
if i := strings.IndexByte(mkeyStr, '/'); i != -1 { |
||||
rem = mkeyStr[i:] |
||||
mkeyStr = mkeyStr[:i] |
||||
} |
||||
|
||||
key, err := wgkey.ParseHex(mkeyStr) |
||||
if err != nil { |
||||
http.Error(w, "bad machine key hex", 400) |
||||
return |
||||
} |
||||
mkey := tailcfg.MachineKey(key) |
||||
|
||||
if r.Method != "POST" { |
||||
http.Error(w, "POST required", 400) |
||||
return |
||||
} |
||||
|
||||
switch rem { |
||||
case "": |
||||
s.serveRegister(w, r, mkey) |
||||
case "/map": |
||||
s.serveMap(w, r, mkey) |
||||
default: |
||||
s.serveUnhandled(w, r) |
||||
} |
||||
} |
||||
|
||||
// Node returns the node for nodeKey. It's always nil or cloned memory.
|
||||
func (s *Server) Node(nodeKey tailcfg.NodeKey) *tailcfg.Node { |
||||
s.mu.Lock() |
||||
defer s.mu.Unlock() |
||||
return s.nodes[nodeKey].Clone() |
||||
} |
||||
|
||||
func (s *Server) getUser(nodeKey tailcfg.NodeKey) (*tailcfg.User, *tailcfg.Login) { |
||||
s.mu.Lock() |
||||
defer s.mu.Unlock() |
||||
if s.users == nil { |
||||
s.users = map[tailcfg.NodeKey]*tailcfg.User{} |
||||
} |
||||
if s.logins == nil { |
||||
s.logins = map[tailcfg.NodeKey]*tailcfg.Login{} |
||||
} |
||||
if u, ok := s.users[nodeKey]; ok { |
||||
return u, s.logins[nodeKey] |
||||
} |
||||
id := tailcfg.UserID(len(s.users) + 1) |
||||
domain := "fake-control.example.net" |
||||
loginName := fmt.Sprintf("user-%d@%s", id, domain) |
||||
displayName := fmt.Sprintf("User %d", id) |
||||
login := &tailcfg.Login{ |
||||
ID: tailcfg.LoginID(id), |
||||
Provider: "testcontrol", |
||||
LoginName: loginName, |
||||
DisplayName: displayName, |
||||
ProfilePicURL: "https://tailscale.com/static/images/marketing/team-carney.jpg", |
||||
Domain: domain, |
||||
} |
||||
user := &tailcfg.User{ |
||||
ID: id, |
||||
LoginName: loginName, |
||||
DisplayName: displayName, |
||||
Domain: domain, |
||||
Logins: []tailcfg.LoginID{login.ID}, |
||||
} |
||||
s.users[nodeKey] = user |
||||
s.logins[nodeKey] = login |
||||
return user, login |
||||
} |
||||
|
||||
func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { |
||||
var req tailcfg.RegisterRequest |
||||
if err := s.decode(mkey, r.Body, &req); err != nil { |
||||
panic(fmt.Sprintf("serveRegister: decode: %v", err)) |
||||
} |
||||
if req.Version != 1 { |
||||
panic(fmt.Sprintf("serveRegister: unsupported version: %d", req.Version)) |
||||
} |
||||
if req.NodeKey.IsZero() { |
||||
panic("serveRegister: request has zero node key") |
||||
} |
||||
|
||||
user, login := s.getUser(req.NodeKey) |
||||
s.mu.Lock() |
||||
if s.nodes == nil { |
||||
s.nodes = map[tailcfg.NodeKey]*tailcfg.Node{} |
||||
} |
||||
s.nodes[req.NodeKey] = &tailcfg.Node{ |
||||
ID: tailcfg.NodeID(user.ID), |
||||
StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), |
||||
User: user.ID, |
||||
Machine: mkey, |
||||
Key: req.NodeKey, |
||||
MachineAuthorized: true, |
||||
} |
||||
s.mu.Unlock() |
||||
|
||||
res, err := s.encode(mkey, false, tailcfg.RegisterResponse{ |
||||
User: *user, |
||||
Login: *login, |
||||
NodeKeyExpired: false, |
||||
MachineAuthorized: true, |
||||
AuthURL: "", // all good; TODO(bradfitz): add ways to not start all good.
|
||||
}) |
||||
if err != nil { |
||||
go panic(fmt.Sprintf("serveRegister: encode: %v", err)) |
||||
} |
||||
w.WriteHeader(200) |
||||
w.Write(res) |
||||
} |
||||
|
||||
// updateType indicates why a long-polling map request is being woken
|
||||
// up for an update.
|
||||
type updateType int |
||||
|
||||
const ( |
||||
// updatePeerChanged is an update that a peer has changed.
|
||||
updatePeerChanged updateType = iota + 1 |
||||
|
||||
// updateSelfChanged is an update that the node changed itself
|
||||
// via a lite endpoint update. These ones are never dup-suppressed,
|
||||
// as the client is expecting an answer regardless.
|
||||
updateSelfChanged |
||||
) |
||||
|
||||
func (s *Server) updateLocked(source string, peers []tailcfg.NodeID) { |
||||
for _, peer := range peers { |
||||
sendUpdate(s.updates[peer], updatePeerChanged) |
||||
} |
||||
} |
||||
|
||||
// sendUpdate sends updateType to dst if dst is non-nil and
|
||||
// has capacity.
|
||||
func sendUpdate(dst chan<- updateType, updateType updateType) { |
||||
if dst == nil { |
||||
return |
||||
} |
||||
// The dst channel has a buffer size of 1.
|
||||
// If we fail to insert an update into the buffer that
|
||||
// means there is already an update pending.
|
||||
select { |
||||
case dst <- updateType: |
||||
default: |
||||
} |
||||
} |
||||
|
||||
func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { |
||||
ctx := r.Context() |
||||
|
||||
req := new(tailcfg.MapRequest) |
||||
if err := s.decode(mkey, r.Body, req); err != nil { |
||||
go panic(fmt.Sprintf("bad map request: %v", err)) |
||||
} |
||||
|
||||
jitter := time.Duration(rand.Intn(8000)) * time.Millisecond |
||||
keepAlive := 50*time.Second + jitter |
||||
|
||||
node := s.Node(req.NodeKey) |
||||
if node == nil { |
||||
http.Error(w, "node not found", 400) |
||||
return |
||||
} |
||||
if node.Machine != mkey { |
||||
http.Error(w, "node doesn't match machine key", 400) |
||||
return |
||||
} |
||||
|
||||
var peersToUpdate []tailcfg.NodeID |
||||
if !req.ReadOnly { |
||||
endpoints := filterInvalidIPv6Endpoints(req.Endpoints) |
||||
node.Endpoints = endpoints |
||||
// TODO: more
|
||||
// TODO: register node,
|
||||
//s.UpdateEndpoint(mkey, req.NodeKey,
|
||||
// XXX
|
||||
} |
||||
|
||||
nodeID := node.ID |
||||
|
||||
s.mu.Lock() |
||||
updatesCh := make(chan updateType, 1) |
||||
oldUpdatesCh := s.updates[nodeID] |
||||
if breakSameNodeMapResponseStreams(req) { |
||||
if oldUpdatesCh != nil { |
||||
close(oldUpdatesCh) |
||||
} |
||||
if s.updates == nil { |
||||
s.updates = map[tailcfg.NodeID]chan updateType{} |
||||
} |
||||
s.updates[nodeID] = updatesCh |
||||
} else { |
||||
sendUpdate(oldUpdatesCh, updateSelfChanged) |
||||
} |
||||
s.updateLocked("serveMap", peersToUpdate) |
||||
s.mu.Unlock() |
||||
|
||||
// ReadOnly implies no streaming, as it doesn't
|
||||
// register an updatesCh to get updates.
|
||||
streaming := req.Stream && !req.ReadOnly |
||||
compress := req.Compress != "" |
||||
|
||||
w.WriteHeader(200) |
||||
for { |
||||
res, err := s.MapResponse(req) |
||||
if err != nil { |
||||
// TODO: log
|
||||
return |
||||
} |
||||
if res == nil { |
||||
return // done
|
||||
} |
||||
// TODO: add minner if/when needed
|
||||
resBytes, err := json.Marshal(res) |
||||
if err != nil { |
||||
s.logf("json.Marshal: %v", err) |
||||
return |
||||
} |
||||
if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { |
||||
return |
||||
} |
||||
if !streaming { |
||||
return |
||||
} |
||||
keepAliveLoop: |
||||
for { |
||||
var keepAliveTimer *time.Timer |
||||
var keepAliveTimerCh <-chan time.Time |
||||
if keepAlive > 0 { |
||||
keepAliveTimer = time.NewTimer(keepAlive) |
||||
keepAliveTimerCh = keepAliveTimer.C |
||||
} |
||||
select { |
||||
case <-ctx.Done(): |
||||
if keepAliveTimer != nil { |
||||
keepAliveTimer.Stop() |
||||
} |
||||
return |
||||
case _, ok := <-updatesCh: |
||||
if !ok { |
||||
// replaced by new poll request
|
||||
return |
||||
} |
||||
break keepAliveLoop |
||||
case <-keepAliveTimerCh: |
||||
if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
var keepAliveMsg = &struct { |
||||
KeepAlive bool |
||||
}{ |
||||
KeepAlive: true, |
||||
} |
||||
|
||||
var prodDERPMap = derpmap.Prod() |
||||
|
||||
// MapResponse generates a MapResponse for a MapRequest.
|
||||
//
|
||||
// No updates to s are done here.
|
||||
func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, err error) { |
||||
node := s.Node(req.NodeKey) |
||||
if node == nil { |
||||
// node key rotated away (once test server supports that)
|
||||
return nil, nil |
||||
} |
||||
derpMap := s.DERPMap |
||||
if derpMap == nil { |
||||
derpMap = prodDERPMap |
||||
} |
||||
user, _ := s.getUser(req.NodeKey) |
||||
res = &tailcfg.MapResponse{ |
||||
Node: node, |
||||
DERPMap: derpMap, |
||||
Domain: string(user.Domain), |
||||
CollectServices: "true", |
||||
PacketFilter: tailcfg.FilterAllowAll, |
||||
} |
||||
res.Node.Addresses = []netaddr.IPPrefix{ |
||||
netaddr.MustParseIPPrefix(fmt.Sprintf("100.64.%d.%d/32", uint8(node.ID>>8), uint8(node.ID))), |
||||
} |
||||
res.Node.AllowedIPs = res.Node.Addresses |
||||
return res, nil |
||||
} |
||||
|
||||
func (s *Server) sendMapMsg(w http.ResponseWriter, mkey tailcfg.MachineKey, compress bool, msg interface{}) error { |
||||
resBytes, err := s.encode(mkey, compress, msg) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if len(resBytes) > 16<<20 { |
||||
return fmt.Errorf("map message too big: %d", len(resBytes)) |
||||
} |
||||
var siz [4]byte |
||||
binary.LittleEndian.PutUint32(siz[:], uint32(len(resBytes))) |
||||
if _, err := w.Write(siz[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(resBytes); err != nil { |
||||
return err |
||||
} |
||||
if f, ok := w.(http.Flusher); ok { |
||||
f.Flush() |
||||
} else { |
||||
s.logf("[unexpected] ResponseWriter %T is not a Flusher", w) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (s *Server) decode(mkey tailcfg.MachineKey, r io.Reader, v interface{}) error { |
||||
if c, _ := r.(io.Closer); c != nil { |
||||
defer c.Close() |
||||
} |
||||
const msgLimit = 1 << 20 |
||||
msg, err := ioutil.ReadAll(io.LimitReader(r, msgLimit)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if len(msg) == msgLimit { |
||||
return errors.New("encrypted message too long") |
||||
} |
||||
|
||||
var nonce [24]byte |
||||
if len(msg) < len(nonce)+1 { |
||||
return errors.New("missing nonce") |
||||
} |
||||
copy(nonce[:], msg) |
||||
msg = msg[len(nonce):] |
||||
|
||||
priv := s.privateKey() |
||||
pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv) |
||||
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri) |
||||
if !ok { |
||||
return errors.New("can't decrypt request") |
||||
} |
||||
return json.Unmarshal(decrypted, v) |
||||
} |
||||
|
||||
var zstdEncoderPool = &sync.Pool{ |
||||
New: func() interface{} { |
||||
encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return encoder |
||||
}, |
||||
} |
||||
|
||||
func (s *Server) encode(mkey tailcfg.MachineKey, compress bool, v interface{}) (b []byte, err error) { |
||||
var isBytes bool |
||||
if b, isBytes = v.([]byte); !isBytes { |
||||
b, err = json.Marshal(v) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
if compress { |
||||
encoder := zstdEncoderPool.Get().(*zstd.Encoder) |
||||
b = encoder.EncodeAll(b, nil) |
||||
encoder.Close() |
||||
zstdEncoderPool.Put(encoder) |
||||
} |
||||
var nonce [24]byte |
||||
if _, err := io.ReadFull(crand.Reader, nonce[:]); err != nil { |
||||
panic(err) |
||||
} |
||||
priv := s.privateKey() |
||||
pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv) |
||||
msgData := box.Seal(nonce[:], b, &nonce, pub, pri) |
||||
return msgData, nil |
||||
} |
||||
|
||||
// filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps,
|
||||
// modify the slice in place, returning the potentially smaller subset (aliasing
|
||||
// the original memory).
|
||||
//
|
||||
// Two types of IPv6 endpoints are considered invalid: link-local
|
||||
// addresses, and anything with a zone.
|
||||
func filterInvalidIPv6Endpoints(eps []string) []string { |
||||
clean := eps[:0] |
||||
for _, ep := range eps { |
||||
if keepClientEndpoint(ep) { |
||||
clean = append(clean, ep) |
||||
} |
||||
} |
||||
return clean |
||||
} |
||||
|
||||
func keepClientEndpoint(ep string) bool { |
||||
ipp, err := netaddr.ParseIPPort(ep) |
||||
if err != nil { |
||||
// Shouldn't have made it this far if we unmarshalled
|
||||
// the incoming JSON response.
|
||||
return false |
||||
} |
||||
ip := ipp.IP |
||||
if ip.Zone() != "" { |
||||
return false |
||||
} |
||||
if ip.Is6() && ip.IsLinkLocalUnicast() { |
||||
// We let clients send these for now, but
|
||||
// tailscaled doesn't know how to use them yet
|
||||
// so we filter them out for now. A future
|
||||
// MapRequest.Version might signal that
|
||||
// clients know how to use them (e.g. try all
|
||||
// local scopes).
|
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// breakSameNodeMapResponseStreams reports whether req should break a
|
||||
// prior long-polling MapResponse stream (if active) from the same
|
||||
// node ID.
|
||||
func breakSameNodeMapResponseStreams(req *tailcfg.MapRequest) bool { |
||||
if req.ReadOnly { |
||||
// Don't register our updatesCh for closability
|
||||
// nor close another peer's if we're a read-only request.
|
||||
return false |
||||
} |
||||
if !req.Stream && req.OmitPeers { |
||||
// Likewise, if we're not streaming and not asking for peers,
|
||||
// (but still mutable, without Readonly set), consider this an endpoint
|
||||
// update request only, and don't close any existing map response
|
||||
// for this nodeID. It's likely the same client with a built-up
|
||||
// compression context. We want to let them update their
|
||||
// new endpoints with us without breaking that other long-running
|
||||
// map response.
|
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
Loading…
Reference in new issue