Signed-off-by: David Anderson <danderson@tailscale.com>main
parent
3e1daab704
commit
da7544bcc5
@ -0,0 +1,330 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package noise implements the base transport of the Tailscale 2021
|
||||
// control protocol.
|
||||
//
|
||||
// The base transport implements Noise IK, instantiated with
|
||||
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
|
||||
package noise |
||||
|
||||
import ( |
||||
"crypto/cipher" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
|
||||
"golang.org/x/crypto/blake2s" |
||||
chp "golang.org/x/crypto/chacha20poly1305" |
||||
"golang.org/x/crypto/poly1305" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
const ( |
||||
maxPlaintextSize = 4096 |
||||
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize |
||||
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
|
||||
) |
||||
|
||||
// A Conn is a secured Noise connection. It implements the net.Conn
|
||||
// interface, with the unusual trait that any write error (including a
|
||||
// SetWriteDeadline induced i/o timeout) cause all future writes to
|
||||
// fail.
|
||||
type Conn struct { |
||||
conn net.Conn |
||||
peer key.Public |
||||
handshakeHash [blake2s.Size]byte |
||||
rx rxState |
||||
tx txState |
||||
} |
||||
|
||||
// rxState is all the Conn state that Read uses.
|
||||
type rxState struct { |
||||
sync.Mutex |
||||
cipher cipher.AEAD |
||||
nonce [chp.NonceSize]byte |
||||
buf [maxPacketSize]byte |
||||
n int // number of valid bytes in buf
|
||||
next int // offset of next undecrypted packet
|
||||
plaintext []byte // slice into buf of decrypted bytes
|
||||
} |
||||
|
||||
// txState is all the Conn state that Write uses.
|
||||
type txState struct { |
||||
sync.Mutex |
||||
cipher cipher.AEAD |
||||
nonce [chp.NonceSize]byte |
||||
buf [maxPacketSize]byte |
||||
err error // records the first partial write error for all future calls
|
||||
} |
||||
|
||||
// HandshakeHash returns the Noise handshake hash for the connection,
|
||||
// which can be used to bind other messages to this connection
|
||||
// (i.e. to ensure that the message wasn't replayed from a different
|
||||
// connection).
|
||||
func (c *Conn) HandshakeHash() [blake2s.Size]byte { |
||||
return c.handshakeHash |
||||
} |
||||
|
||||
// Peer returns the peer's long-term public key.
|
||||
func (c *Conn) Peer() key.Public { |
||||
return c.peer |
||||
} |
||||
|
||||
// validNonce reports whether nonce is in the valid range for use: 0
|
||||
// through 2^64-2.
|
||||
func validNonce(nonce []byte) bool { |
||||
return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce |
||||
} |
||||
|
||||
// readNLocked reads into c.rxBuf until rxBuf contains at least total
|
||||
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
||||
// error if fewer than total bytes are available.
|
||||
func (c *Conn) readNLocked(total int) ([]byte, error) { |
||||
if total > maxPacketSize { |
||||
return nil, errReadTooBig{total} |
||||
} |
||||
for { |
||||
if total <= c.rx.n { |
||||
return c.rx.buf[:c.rx.n], nil |
||||
} |
||||
|
||||
n, err := c.conn.Read(c.rx.buf[c.rx.n:]) |
||||
c.rx.n += n |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
|
||||
// to the decrypted bytes. Returns an error if the cipher is exhausted
|
||||
// (i.e. can no longer be used safely) or decryption fails.
|
||||
func (c *Conn) decryptLocked(ciphertext []byte) (err error) { |
||||
if !validNonce(c.rx.nonce[:]) { |
||||
return errCipherExhausted{} |
||||
} |
||||
|
||||
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) |
||||
|
||||
// Safe to increment the nonce here, because we checked for nonce
|
||||
// wraparound above.
|
||||
binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:])) |
||||
|
||||
if err != nil { |
||||
// Once a decryption has failed, our Conn is no longer
|
||||
// synchronized with our peer. Nuke the cipher state to be
|
||||
// safe, so that no further decryptions are attempted.
|
||||
c.rx.cipher = nil |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
||||
// 2-byte length header) and returns a slice of the ciphertext, or an
|
||||
// error if the cipher is exhausted (i.e. can no longer be used safely).
|
||||
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { |
||||
if !validNonce(c.tx.nonce[:]) { |
||||
// Received 2^64-1 messages on this cipher state. Connection
|
||||
// is no longer usable.
|
||||
return nil, errCipherExhausted{} |
||||
} |
||||
|
||||
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize)) |
||||
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil) |
||||
|
||||
// Safe to increment the nonce here, because we checked for nonce
|
||||
// wraparound above.
|
||||
binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:])) |
||||
|
||||
return ret, nil |
||||
} |
||||
|
||||
// wholeCiphertextLocked returns a slice of one whole Noise frame from
|
||||
// c.rx.buf, if one whole ciphertext is available, and advances the
|
||||
// read state to the next Noise frame in the buffer. Returns nil
|
||||
// without advancing read state if there's not one whole ciphertext in
|
||||
// c.rx.buf.
|
||||
func (c *Conn) wholeCiphertextLocked() []byte { |
||||
available := c.rx.n - c.rx.next |
||||
if available < 2 { |
||||
return nil |
||||
} |
||||
bs := c.rx.buf[c.rx.next:c.rx.n] |
||||
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2 |
||||
if len(bs) < totalSize { |
||||
return nil |
||||
} |
||||
c.rx.next += totalSize |
||||
return bs[:totalSize] |
||||
} |
||||
|
||||
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
|
||||
// and sets c.rx.plaintext to point to the decrypted
|
||||
// bytes. c.rx.plaintext is only valid if err == nil.
|
||||
func (c *Conn) decryptOneLocked() error { |
||||
c.rx.plaintext = nil |
||||
|
||||
// Fast path: do we have one whole ciphertext frame buffered
|
||||
// already?
|
||||
if bs := c.wholeCiphertextLocked(); bs != nil { |
||||
return c.decryptLocked(bs[2:]) |
||||
} |
||||
|
||||
if c.rx.next != 0 { |
||||
// To simplify the read logic, move the remainder of the
|
||||
// buffered bytes back to the head of the buffer, so we can
|
||||
// grow it without worrying about wraparound.
|
||||
copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) |
||||
c.rx.n -= c.rx.next |
||||
c.rx.next = 0 |
||||
} |
||||
|
||||
bs, err := c.readNLocked(2) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2 |
||||
bs, err = c.readNLocked(totalLen) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
c.rx.next = totalLen |
||||
bs = bs[2:totalLen] |
||||
|
||||
return c.decryptLocked(bs) |
||||
} |
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *Conn) Read(bs []byte) (int, error) { |
||||
c.rx.Lock() |
||||
defer c.rx.Unlock() |
||||
|
||||
if c.rx.cipher == nil { |
||||
return 0, net.ErrClosed |
||||
} |
||||
// Loop to handle receiving a zero-byte Noise message. Just skip
|
||||
// over it and keep decrypting until we find some bytes.
|
||||
for len(c.rx.plaintext) == 0 { |
||||
if err := c.decryptOneLocked(); err != nil { |
||||
return 0, err |
||||
} |
||||
} |
||||
n := copy(bs, c.rx.plaintext) |
||||
c.rx.plaintext = c.rx.plaintext[n:] |
||||
return n, nil |
||||
} |
||||
|
||||
// Write implements io.Writer.
|
||||
func (c *Conn) Write(bs []byte) (n int, err error) { |
||||
c.tx.Lock() |
||||
defer c.tx.Unlock() |
||||
|
||||
if c.tx.err != nil { |
||||
return 0, c.tx.err |
||||
} |
||||
defer func() { |
||||
if err != nil { |
||||
// All write errors are fatal for this conn, so clear the
|
||||
// cipher state whenever an error happens.
|
||||
c.tx.cipher = nil |
||||
} |
||||
if c.tx.err == nil { |
||||
// Only set c.tx.err if not nil so that we can return one
|
||||
// error on the first failure, and a different one for
|
||||
// subsequent calls. See the error handling around Write
|
||||
// below for why.
|
||||
c.tx.err = err |
||||
} |
||||
}() |
||||
|
||||
if c.tx.cipher == nil { |
||||
return 0, net.ErrClosed |
||||
} |
||||
|
||||
var sent int |
||||
for len(bs) > 0 { |
||||
toSend := bs |
||||
if len(toSend) > maxPlaintextSize { |
||||
toSend = bs[:maxPlaintextSize] |
||||
} |
||||
bs = bs[len(toSend):] |
||||
|
||||
ciphertext, err := c.encryptLocked(toSend) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
if n, err := c.conn.Write(ciphertext); err != nil { |
||||
sent += n |
||||
// Return the raw error on the Write that actually
|
||||
// failed. For future writes, return that error wrapped in
|
||||
// a desync error.
|
||||
c.tx.err = errPartialWrite{err} |
||||
return sent, err |
||||
} |
||||
sent += len(toSend) |
||||
} |
||||
return sent, nil |
||||
} |
||||
|
||||
// Close implements io.Closer.
|
||||
func (c *Conn) Close() error { |
||||
closeErr := c.conn.Close() // unblocks any waiting reads or writes
|
||||
c.rx.Lock() |
||||
c.rx.cipher = nil |
||||
c.rx.Unlock() |
||||
c.tx.Lock() |
||||
c.tx.cipher = nil |
||||
c.tx.Unlock() |
||||
return closeErr |
||||
} |
||||
|
||||
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } |
||||
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } |
||||
func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } |
||||
func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } |
||||
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } |
||||
|
||||
// errCipherExhausted is the error returned when we run out of nonces
|
||||
// on a cipher.
|
||||
type errCipherExhausted struct{} |
||||
|
||||
func (errCipherExhausted) Error() string { |
||||
return "cipher exhausted, no more nonces available for current key" |
||||
} |
||||
func (errCipherExhausted) Timeout() bool { return false } |
||||
func (errCipherExhausted) Temporary() bool { return false } |
||||
|
||||
// errPartialWrite is the error returned when the cipher state has
|
||||
// become unusable due to a past partial write.
|
||||
type errPartialWrite struct { |
||||
err error |
||||
} |
||||
|
||||
func (e errPartialWrite) Error() string { |
||||
return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) |
||||
} |
||||
func (e errPartialWrite) Unwrap() error { return e.err } |
||||
func (e errPartialWrite) Temporary() bool { return false } |
||||
func (e errPartialWrite) Timeout() bool { return false } |
||||
|
||||
// errReadTooBig is the error returned when the peer sent an
|
||||
// unacceptably large Noise frame.
|
||||
type errReadTooBig struct { |
||||
requested int |
||||
} |
||||
|
||||
func (e errReadTooBig) Error() string { |
||||
return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) |
||||
} |
||||
func (e errReadTooBig) Temporary() bool { |
||||
// permanent error because this error only occurs when our peer
|
||||
// sends us a frame so large we're unwilling to ever decode it.
|
||||
return false |
||||
} |
||||
func (e errReadTooBig) Timeout() bool { return false } |
||||
@ -0,0 +1,339 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"context" |
||||
"crypto/rand" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
"testing/iotest" |
||||
|
||||
chp "golang.org/x/crypto/chacha20poly1305" |
||||
"golang.org/x/net/nettest" |
||||
tsnettest "tailscale.com/net/nettest" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
func TestMessageSize(t *testing.T) { |
||||
// This test is a regression guard against someone looking at
|
||||
// maxCiphertextSize, going "huh, we could be more efficient if it
|
||||
// were larger, and accidentally violating the Noise spec. Do not
|
||||
// change this max value, it's a deliberate limitation of the
|
||||
// cryptographic protocol we use (see Section 3 "Message Format"
|
||||
// of the Noise spec).
|
||||
const max = 65535 |
||||
if maxCiphertextSize > max { |
||||
t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max) |
||||
} |
||||
} |
||||
|
||||
func TestConnBasic(t *testing.T) { |
||||
client, server := pair(t) |
||||
|
||||
sb := sinkReads(server) |
||||
|
||||
want := "test" |
||||
if _, err := io.WriteString(client, want); err != nil { |
||||
t.Fatalf("client write failed: %v", err) |
||||
} |
||||
client.Close() |
||||
|
||||
if got := sb.String(4); got != want { |
||||
t.Fatalf("wrong content received: got %q, want %q", got, want) |
||||
} |
||||
if err := sb.Error(); err != io.EOF { |
||||
t.Fatal("client close wasn't seen by server") |
||||
} |
||||
if sb.Total() != 4 { |
||||
t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total()) |
||||
} |
||||
} |
||||
|
||||
// bufferedWriteConn wraps a net.Conn and gives control over how
|
||||
// Writes get batched out.
|
||||
type bufferedWriteConn struct { |
||||
net.Conn |
||||
w *bufio.Writer |
||||
manualFlush bool |
||||
} |
||||
|
||||
func (c *bufferedWriteConn) Write(bs []byte) (int, error) { |
||||
n, err := c.w.Write(bs) |
||||
if err == nil && !c.manualFlush { |
||||
err = c.w.Flush() |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// TestFastPath exercises the Read codepath that can receive multiple
|
||||
// Noise frames at once and decode each in turn without making another
|
||||
// syscall.
|
||||
func TestFastPath(t *testing.T) { |
||||
s1, s2 := tsnettest.NewConn("noise", 128000) |
||||
b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false} |
||||
client, server := pairWithConns(t, b, s2) |
||||
|
||||
b.manualFlush = true |
||||
|
||||
sb := sinkReads(server) |
||||
|
||||
const packets = 10 |
||||
s := "test" |
||||
for i := 0; i < packets; i++ { |
||||
// Many separate writes, to force separate Noise frames that
|
||||
// all get buffered up and then all sent as a single slice to
|
||||
// the server.
|
||||
if _, err := io.WriteString(client, s); err != nil { |
||||
t.Fatalf("client write1 failed: %v", err) |
||||
} |
||||
} |
||||
if err := b.w.Flush(); err != nil { |
||||
t.Fatalf("client flush failed: %v", err) |
||||
} |
||||
client.Close() |
||||
|
||||
want := strings.Repeat(s, packets) |
||||
if got := sb.String(len(want)); got != want { |
||||
t.Fatalf("wrong content received: got %q, want %q", got, want) |
||||
} |
||||
if err := sb.Error(); err != io.EOF { |
||||
t.Fatalf("client close wasn't seen by server") |
||||
} |
||||
} |
||||
|
||||
// Writes things larger than a single Noise frame, to check the
|
||||
// chunking on the encoder and decoder.
|
||||
func TestBigData(t *testing.T) { |
||||
client, server := pair(t) |
||||
|
||||
serverReads := sinkReads(server) |
||||
clientReads := sinkReads(client) |
||||
|
||||
const sz = 15 * 1024 // 15KiB
|
||||
clientStr := strings.Repeat("abcde", sz/5) |
||||
serverStr := strings.Repeat("fghij", sz/5*2) |
||||
|
||||
if _, err := io.WriteString(client, clientStr); err != nil { |
||||
t.Fatalf("writing client>server: %v", err) |
||||
} |
||||
if _, err := io.WriteString(server, serverStr); err != nil { |
||||
t.Fatalf("writing server>client: %v", err) |
||||
} |
||||
|
||||
if serverGot := serverReads.String(sz); serverGot != clientStr { |
||||
t.Error("server didn't receive what client sent") |
||||
} |
||||
if clientGot := clientReads.String(2 * sz); clientGot != serverStr { |
||||
t.Error("client didn't receive what server sent") |
||||
} |
||||
|
||||
getNonce := func(n [chp.NonceSize]byte) uint64 { |
||||
if binary.BigEndian.Uint32(n[:4]) != 0 { |
||||
panic("unexpected nonce") |
||||
} |
||||
return binary.BigEndian.Uint64(n[4:]) |
||||
} |
||||
|
||||
// Reach into the Conns and verify the cipher nonces advanced as
|
||||
// expected.
|
||||
if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) { |
||||
t.Error("desynchronized client tx nonce") |
||||
} |
||||
if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) { |
||||
t.Error("desynchronized server tx nonce") |
||||
} |
||||
if n := getNonce(client.tx.nonce); n != 4 { |
||||
t.Errorf("wrong client tx nonce, got %d want 4", n) |
||||
} |
||||
if n := getNonce(server.tx.nonce); n != 8 { |
||||
t.Errorf("wrong client tx nonce, got %d want 8", n) |
||||
} |
||||
} |
||||
|
||||
// readerConn wraps a net.Conn and routes its Reads through a separate
|
||||
// io.Reader.
|
||||
type readerConn struct { |
||||
net.Conn |
||||
r io.Reader |
||||
} |
||||
|
||||
func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) } |
||||
|
||||
// Check that the receiver can handle not being able to read an entire
|
||||
// frame in a single syscall.
|
||||
func TestDataTrickle(t *testing.T) { |
||||
s1, s2 := tsnettest.NewConn("noise", 128000) |
||||
client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)}) |
||||
serverReads := sinkReads(server) |
||||
|
||||
const sz = 10000 |
||||
clientStr := strings.Repeat("abcde", sz/5) |
||||
if _, err := io.WriteString(client, clientStr); err != nil { |
||||
t.Fatalf("writing client>server: %v", err) |
||||
} |
||||
|
||||
serverGot := serverReads.String(sz) |
||||
if serverGot != clientStr { |
||||
t.Error("server didn't receive what client sent") |
||||
} |
||||
} |
||||
|
||||
func TestConnStd(t *testing.T) { |
||||
// You can run this test manually, and noise.Conn should pass all
|
||||
// of them except for TestConn/PastTimeout,
|
||||
// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
|
||||
// those tests assume that write errors are recoverable, and
|
||||
// they're not on our Conn due to cipher security.
|
||||
t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977") |
||||
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { |
||||
s1, s2 := tsnettest.NewConn("noise", 4096) |
||||
controlKey := key.NewPrivate() |
||||
machineKey := key.NewPrivate() |
||||
serverErr := make(chan error, 1) |
||||
go func() { |
||||
var err error |
||||
c2, err = Server(context.Background(), s2, controlKey) |
||||
serverErr <- err |
||||
}() |
||||
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) |
||||
if err != nil { |
||||
s1.Close() |
||||
s2.Close() |
||||
return nil, nil, nil, fmt.Errorf("connecting client: %w", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
c1.Close() |
||||
s1.Close() |
||||
s2.Close() |
||||
return nil, nil, nil, fmt.Errorf("connecting server: %w", err) |
||||
} |
||||
return c1, c2, func() { |
||||
c1.Close() |
||||
c2.Close() |
||||
}, nil |
||||
}) |
||||
} |
||||
|
||||
// mkConns creates synthetic Noise Conns wrapping the given net.Conns.
|
||||
// This function is for testing just the Conn transport logic without
|
||||
// having to muck about with Noise handshakes.
|
||||
func mkConns(s1, s2 net.Conn) (*Conn, *Conn) { |
||||
var k1, k2 [chp.KeySize]byte |
||||
if _, err := rand.Read(k1[:]); err != nil { |
||||
panic(err) |
||||
} |
||||
if _, err := rand.Read(k2[:]); err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
ret1 := &Conn{ |
||||
conn: s1, |
||||
tx: txState{cipher: newCHP(k1)}, |
||||
rx: rxState{cipher: newCHP(k2)}, |
||||
} |
||||
ret2 := &Conn{ |
||||
conn: s2, |
||||
tx: txState{cipher: newCHP(k2)}, |
||||
rx: rxState{cipher: newCHP(k1)}, |
||||
} |
||||
|
||||
return ret1, ret2 |
||||
} |
||||
|
||||
type readSink struct { |
||||
r io.Reader |
||||
|
||||
cond *sync.Cond |
||||
sync.Mutex |
||||
bs bytes.Buffer |
||||
err error |
||||
} |
||||
|
||||
func sinkReads(r io.Reader) *readSink { |
||||
ret := &readSink{ |
||||
r: r, |
||||
} |
||||
ret.cond = sync.NewCond(&ret.Mutex) |
||||
go func() { |
||||
var buf [4096]byte |
||||
for { |
||||
n, err := r.Read(buf[:]) |
||||
ret.Lock() |
||||
ret.bs.Write(buf[:n]) |
||||
if err != nil { |
||||
ret.err = err |
||||
} |
||||
ret.cond.Broadcast() |
||||
ret.Unlock() |
||||
if err != nil { |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
return ret |
||||
} |
||||
|
||||
func (s *readSink) String(total int) string { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
for s.bs.Len() < total && s.err == nil { |
||||
s.cond.Wait() |
||||
} |
||||
if s.err != nil { |
||||
total = s.bs.Len() |
||||
} |
||||
return string(s.bs.Bytes()[:total]) |
||||
} |
||||
|
||||
func (s *readSink) Error() error { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
for s.err == nil { |
||||
s.cond.Wait() |
||||
} |
||||
return s.err |
||||
} |
||||
|
||||
func (s *readSink) Total() int { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
return s.bs.Len() |
||||
} |
||||
|
||||
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) { |
||||
var ( |
||||
controlKey = key.NewPrivate() |
||||
machineKey = key.NewPrivate() |
||||
server *Conn |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
var err error |
||||
server, err = Server(context.Background(), serverConn, controlKey) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public()) |
||||
if err != nil { |
||||
t.Fatalf("client connection failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server connection failed: %v", err) |
||||
} |
||||
return client, server |
||||
} |
||||
|
||||
func pair(t *testing.T) (*Conn, *Conn) { |
||||
s1, s2 := tsnettest.NewConn("noise", 128000) |
||||
return pairWithConns(t, s1, s2) |
||||
} |
||||
@ -0,0 +1,361 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/cipher" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"hash" |
||||
"io" |
||||
"net" |
||||
"time" |
||||
|
||||
"golang.org/x/crypto/blake2s" |
||||
chp "golang.org/x/crypto/chacha20poly1305" |
||||
"golang.org/x/crypto/curve25519" |
||||
"golang.org/x/crypto/hkdf" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
const ( |
||||
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" |
||||
invalidNonce = ^uint64(0) |
||||
) |
||||
|
||||
// Client initiates a Noise client handshake, returning the resulting
|
||||
// Noise connection.
|
||||
//
|
||||
// The context deadline, if any, covers the entire handshaking
|
||||
// process.
|
||||
func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) { |
||||
if deadline, ok := ctx.Deadline(); ok { |
||||
if err := conn.SetDeadline(deadline); err != nil { |
||||
return nil, fmt.Errorf("setting conn deadline: %w", err) |
||||
} |
||||
defer func() { |
||||
conn.SetDeadline(time.Time{}) |
||||
}() |
||||
} |
||||
|
||||
var s symmetricState |
||||
s.Initialize() |
||||
|
||||
// <- s
|
||||
// ...
|
||||
s.MixHash(controlKey[:]) |
||||
|
||||
var init initiationMessage |
||||
// -> e, es, s, ss
|
||||
machineEphemeral := key.NewPrivate() |
||||
machineEphemeralPub := machineEphemeral.Public() |
||||
copy(init.MachineEphemeralPub(), machineEphemeralPub[:]) |
||||
s.MixHash(machineEphemeralPub[:]) |
||||
if err := s.MixDH(machineEphemeral, controlKey); err != nil { |
||||
return nil, fmt.Errorf("computing es: %w", err) |
||||
} |
||||
machineKeyPub := machineKey.Public() |
||||
copy(init.MachinePub(), s.EncryptAndHash(machineKeyPub[:])) |
||||
if err := s.MixDH(machineKey, controlKey); err != nil { |
||||
return nil, fmt.Errorf("computing ss: %w", err) |
||||
} |
||||
copy(init.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||
|
||||
if _, err := conn.Write(init[:]); err != nil { |
||||
return nil, fmt.Errorf("writing initiation: %w", err) |
||||
} |
||||
|
||||
// <- e, ee, se
|
||||
var resp responseMessage |
||||
if _, err := io.ReadFull(conn, resp[:]); err != nil { |
||||
return nil, fmt.Errorf("reading response: %w", err) |
||||
} |
||||
|
||||
var controlEphemeralPub key.Public |
||||
copy(controlEphemeralPub[:], resp.ControlEphemeralPub()) |
||||
s.MixHash(controlEphemeralPub[:]) |
||||
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { |
||||
return nil, fmt.Errorf("computing ee: %w", err) |
||||
} |
||||
if err := s.MixDH(machineKey, controlEphemeralPub); err != nil { |
||||
return nil, fmt.Errorf("computing se: %w", err) |
||||
} |
||||
if _, err := s.DecryptAndHash(resp.Tag()); err != nil { |
||||
return nil, fmt.Errorf("decrypting payload: %w", err) |
||||
} |
||||
|
||||
c1, c2, err := s.Split() |
||||
if err != nil { |
||||
return nil, fmt.Errorf("finalizing handshake: %w", err) |
||||
} |
||||
|
||||
return &Conn{ |
||||
conn: conn, |
||||
peer: controlKey, |
||||
handshakeHash: s.h, |
||||
tx: txState{ |
||||
cipher: c1, |
||||
}, |
||||
rx: rxState{ |
||||
cipher: c2, |
||||
}, |
||||
}, nil |
||||
} |
||||
|
||||
// Server initiates a Noise server handshake, returning the resulting
|
||||
// Noise connection.
|
||||
//
|
||||
// The context deadline, if any, covers the entire handshaking
|
||||
// process.
|
||||
func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, error) { |
||||
if deadline, ok := ctx.Deadline(); ok { |
||||
if err := conn.SetDeadline(deadline); err != nil { |
||||
return nil, fmt.Errorf("setting conn deadline: %w", err) |
||||
} |
||||
defer func() { |
||||
conn.SetDeadline(time.Time{}) |
||||
}() |
||||
} |
||||
|
||||
var s symmetricState |
||||
s.Initialize() |
||||
|
||||
// <- s
|
||||
// ...
|
||||
controlKeyPub := controlKey.Public() |
||||
s.MixHash(controlKeyPub[:]) |
||||
|
||||
// -> e, es, s, ss
|
||||
var init initiationMessage |
||||
if _, err := io.ReadFull(conn, init[:]); err != nil { |
||||
return nil, fmt.Errorf("reading initiation: %w", err) |
||||
} |
||||
|
||||
var machineEphemeralPub key.Public |
||||
copy(machineEphemeralPub[:], init.MachineEphemeralPub()) |
||||
s.MixHash(machineEphemeralPub[:]) |
||||
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil { |
||||
return nil, fmt.Errorf("computing es: %w", err) |
||||
} |
||||
var machineKey key.Public |
||||
rs, err := s.DecryptAndHash(init.MachinePub()) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("decrypting machine key: %w", err) |
||||
} |
||||
copy(machineKey[:], rs) |
||||
if err := s.MixDH(controlKey, machineKey); err != nil { |
||||
return nil, fmt.Errorf("computing ss: %w", err) |
||||
} |
||||
if _, err := s.DecryptAndHash(init.Tag()); err != nil { |
||||
return nil, fmt.Errorf("decrypting initiation tag: %w", err) |
||||
} |
||||
|
||||
// <- e, ee, se
|
||||
var resp responseMessage |
||||
controlEphemeral := key.NewPrivate() |
||||
controlEphemeralPub := controlEphemeral.Public() |
||||
copy(resp.ControlEphemeralPub(), controlEphemeralPub[:]) |
||||
s.MixHash(controlEphemeralPub[:]) |
||||
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { |
||||
return nil, fmt.Errorf("computing ee: %w", err) |
||||
} |
||||
if err := s.MixDH(controlEphemeral, machineKey); err != nil { |
||||
return nil, fmt.Errorf("computing se: %w", err) |
||||
} |
||||
copy(resp.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||
|
||||
c1, c2, err := s.Split() |
||||
if err != nil { |
||||
return nil, fmt.Errorf("finalizing handshake: %w", err) |
||||
} |
||||
|
||||
if _, err := conn.Write(resp[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &Conn{ |
||||
conn: conn, |
||||
peer: machineKey, |
||||
handshakeHash: s.h, |
||||
tx: txState{ |
||||
cipher: c2, |
||||
}, |
||||
rx: rxState{ |
||||
cipher: c1, |
||||
}, |
||||
}, nil |
||||
} |
||||
|
||||
// initiationMessage is the Noise protocol message sent from a client
|
||||
// machine to a control server.
|
||||
type initiationMessage [96]byte |
||||
|
||||
func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] } |
||||
func (m *initiationMessage) MachinePub() []byte { return m[32:80] } |
||||
func (m *initiationMessage) Tag() []byte { return m[80:] } |
||||
|
||||
// responseMessage is the Noise protocol message sent from a control
|
||||
// server to a client machine.
|
||||
type responseMessage [48]byte |
||||
|
||||
func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] } |
||||
func (m *responseMessage) Tag() []byte { return m[32:] } |
||||
|
||||
// symmetricState is the SymmetricState object from the Noise protocol
|
||||
// spec. It contains all the symmetric cipher state of an in-flight
|
||||
// handshake. Field names match the variable names in the spec.
|
||||
type symmetricState struct { |
||||
h [blake2s.Size]byte |
||||
ck [blake2s.Size]byte |
||||
|
||||
k [chp.KeySize]byte |
||||
n uint64 |
||||
|
||||
mixer hash.Hash // for updating h
|
||||
} |
||||
|
||||
// Initialize sets s to the initial handshake state, prior to
|
||||
// processing any Noise messages.
|
||||
func (s *symmetricState) Initialize() { |
||||
if s.mixer != nil { |
||||
panic("symmetricState cannot be reused") |
||||
} |
||||
s.h = blake2s.Sum256([]byte(protocolName)) |
||||
s.ck = s.h |
||||
s.k = [chp.KeySize]byte{} |
||||
s.n = invalidNonce |
||||
s.mixer = newBLAKE2s() |
||||
// Mix in an empty prologue.
|
||||
s.MixHash(nil) |
||||
} |
||||
|
||||
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
||||
// concatenation.
|
||||
func (s *symmetricState) MixHash(data []byte) { |
||||
s.mixer.Reset() |
||||
s.mixer.Write(s.h[:]) |
||||
s.mixer.Write(data) |
||||
s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly...
|
||||
} |
||||
|
||||
// MixDH updates s.ck and s.k with the result of X25519(priv, pub).
|
||||
//
|
||||
// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
|
||||
// it as a single function allows for strongly-typed arguments that
|
||||
// reduce the risk of error in the caller (e.g. invoking X25519 with
|
||||
// two private keys, or two public keys), and thus producing the wrong
|
||||
// calculation.
|
||||
func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error { |
||||
// TODO(danderson): check that this operation is correct. The docs
|
||||
// for X25519 say that the 2nd arg must be either Basepoint or the
|
||||
// output of another X25519 call.
|
||||
//
|
||||
// I think this is correct, because pub is the result of a
|
||||
// ScalarBaseMult on the private key, and our private key
|
||||
// generation code clamps keys to avoid low order points. I
|
||||
// believe that makes pub equivalent to the output of
|
||||
// X25519(privateKey, Basepoint), and so the contract is
|
||||
// respected.
|
||||
keyData, err := curve25519.X25519(priv[:], pub[:]) |
||||
if err != nil { |
||||
return fmt.Errorf("computing X25519: %w", err) |
||||
} |
||||
|
||||
r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) |
||||
if _, err := io.ReadFull(r, s.ck[:]); err != nil { |
||||
return fmt.Errorf("extracting ck: %w", err) |
||||
} |
||||
if _, err := io.ReadFull(r, s.k[:]); err != nil { |
||||
return fmt.Errorf("extracting k: %w", err) |
||||
} |
||||
s.n = 0 |
||||
return nil |
||||
} |
||||
|
||||
// EncryptAndHash encrypts the given plaintext using the current s.k,
|
||||
// mixes the ciphertext into s.h, and returns the ciphertext.
|
||||
func (s *symmetricState) EncryptAndHash(plaintext []byte) []byte { |
||||
if s.n == invalidNonce { |
||||
// Noise in general permits writing "ciphertext" without a
|
||||
// key, but in IK it cannot happen.
|
||||
panic("attempted encryption with uninitialized key") |
||||
} |
||||
aead := newCHP(s.k) |
||||
var nonce [chp.NonceSize]byte |
||||
binary.BigEndian.PutUint64(nonce[4:], s.n) |
||||
s.n++ |
||||
ret := aead.Seal(nil, nonce[:], plaintext, s.h[:]) |
||||
s.MixHash(ret) |
||||
return ret |
||||
} |
||||
|
||||
// DecryptAndHash decrypts the given ciphertext using the current
|
||||
// s.k. If decryption is successful, it mixes the ciphertext into s.h
|
||||
// and returns the plaintext.
|
||||
func (s *symmetricState) DecryptAndHash(ciphertext []byte) ([]byte, error) { |
||||
if s.n == invalidNonce { |
||||
// Noise in general permits "ciphertext" without a key, but in
|
||||
// IK it cannot happen.
|
||||
panic("attempted encryption with uninitialized key") |
||||
} |
||||
aead := newCHP(s.k) |
||||
var nonce [chp.NonceSize]byte |
||||
binary.BigEndian.PutUint64(nonce[4:], s.n) |
||||
s.n++ |
||||
ret, err := aead.Open(nil, nonce[:], ciphertext, s.h[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
s.MixHash(ciphertext) |
||||
return ret, nil |
||||
} |
||||
|
||||
// Split returns two ChaCha20Poly1305 ciphers with keys derives from
|
||||
// the current handshake state. Methods on s must not be used again
|
||||
// after calling Split().
|
||||
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { |
||||
var k1, k2 [chp.KeySize]byte |
||||
r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) |
||||
if _, err := io.ReadFull(r, k1[:]); err != nil { |
||||
return nil, nil, fmt.Errorf("extracting k1: %w", err) |
||||
} |
||||
if _, err := io.ReadFull(r, k2[:]); err != nil { |
||||
return nil, nil, fmt.Errorf("extracting k2: %w", err) |
||||
} |
||||
c1, err = chp.New(k1[:]) |
||||
if err != nil { |
||||
return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) |
||||
} |
||||
c2, err = chp.New(k2[:]) |
||||
if err != nil { |
||||
return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) |
||||
} |
||||
return c1, c2, nil |
||||
} |
||||
|
||||
// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
|
||||
// error.
|
||||
func newBLAKE2s() hash.Hash { |
||||
h, err := blake2s.New256(nil) |
||||
if err != nil { |
||||
// Should never happen, errors only happen when using BLAKE2s
|
||||
// in MAC mode with a key.
|
||||
panic(fmt.Sprintf("blake2s construction: %v", err)) |
||||
} |
||||
return h |
||||
} |
||||
|
||||
// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
|
||||
// panics on error.
|
||||
func newCHP(key [chp.KeySize]byte) cipher.AEAD { |
||||
aead, err := chp.New(key[:]) |
||||
if err != nil { |
||||
// Can only happen if we passed a key of the wrong length. The
|
||||
// function signature prevents that.
|
||||
panic(fmt.Sprintf("chacha20poly1305 construction: %v", err)) |
||||
} |
||||
return aead |
||||
} |
||||
@ -0,0 +1,290 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"io" |
||||
"strings" |
||||
"testing" |
||||
"time" |
||||
|
||||
tsnettest "tailscale.com/net/nettest" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
func TestHandshake(t *testing.T) { |
||||
var ( |
||||
clientConn, serverConn = tsnettest.NewConn("noise", 128000) |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
server *Conn |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
var err error |
||||
server, err = Server(context.Background(), serverConn, serverKey) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err != nil { |
||||
t.Fatalf("client connection failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server connection failed: %v", err) |
||||
} |
||||
|
||||
if client.HandshakeHash() != server.HandshakeHash() { |
||||
t.Fatal("client and server disagree on handshake hash") |
||||
} |
||||
|
||||
if client.Peer() != serverKey.Public() { |
||||
t.Fatal("client peer key isn't serverKey") |
||||
} |
||||
if server.Peer() != clientKey.Public() { |
||||
t.Fatal("client peer key isn't serverKey") |
||||
} |
||||
} |
||||
|
||||
// Check that handshaking repeatedly with the same long-term keys
|
||||
// result in different handshake hashes and wire traffic.
|
||||
func TestNoReuse(t *testing.T) { |
||||
var ( |
||||
hashes = map[[32]byte]bool{} |
||||
clientHandshakes = map[[96]byte]bool{} |
||||
serverHandshakes = map[[48]byte]bool{} |
||||
packets = map[[32]byte]bool{} |
||||
) |
||||
for i := 0; i < 10; i++ { |
||||
var ( |
||||
clientRaw, serverRaw = tsnettest.NewConn("noise", 128000) |
||||
clientBuf, serverBuf bytes.Buffer |
||||
clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)} |
||||
serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)} |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
server *Conn |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
var err error |
||||
server, err = Server(context.Background(), serverConn, serverKey) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err != nil { |
||||
t.Fatalf("client connection failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server connection failed: %v", err) |
||||
} |
||||
|
||||
var clientHS [96]byte |
||||
copy(clientHS[:], serverBuf.Bytes()) |
||||
if clientHandshakes[clientHS] { |
||||
t.Fatal("client handshake seen twice") |
||||
} |
||||
clientHandshakes[clientHS] = true |
||||
|
||||
var serverHS [48]byte |
||||
copy(serverHS[:], clientBuf.Bytes()) |
||||
if serverHandshakes[serverHS] { |
||||
t.Fatal("server handshake seen twice") |
||||
} |
||||
serverHandshakes[serverHS] = true |
||||
|
||||
clientBuf.Reset() |
||||
serverBuf.Reset() |
||||
cb := sinkReads(client) |
||||
sb := sinkReads(server) |
||||
|
||||
if hashes[client.HandshakeHash()] { |
||||
t.Fatalf("handshake hash %v seen twice", client.HandshakeHash()) |
||||
} |
||||
hashes[client.HandshakeHash()] = true |
||||
|
||||
// Sending 14 bytes turns into 32 bytes on the wire (+16 for
|
||||
// the poly1305 tag, +2 length header)
|
||||
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil { |
||||
t.Fatalf("client>server write failed: %v", err) |
||||
} |
||||
if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil { |
||||
t.Fatalf("server>client write failed: %v", err) |
||||
} |
||||
|
||||
// Wait for the bytes to be read, so we know they've traveled end to end
|
||||
cb.String(14) |
||||
sb.String(14) |
||||
|
||||
var clientWire, serverWire [32]byte |
||||
copy(clientWire[:], clientBuf.Bytes()) |
||||
copy(serverWire[:], serverBuf.Bytes()) |
||||
|
||||
if packets[clientWire] { |
||||
t.Fatalf("client wire traffic seen twice") |
||||
} |
||||
packets[clientWire] = true |
||||
if packets[serverWire] { |
||||
t.Fatalf("server wire traffic seen twice") |
||||
} |
||||
packets[serverWire] = true |
||||
} |
||||
} |
||||
|
||||
// tamperReader wraps a reader and mutates the Nth byte.
|
||||
type tamperReader struct { |
||||
r io.Reader |
||||
n int |
||||
total int |
||||
} |
||||
|
||||
func (r *tamperReader) Read(bs []byte) (int, error) { |
||||
n, err := r.r.Read(bs) |
||||
if off := r.n - r.total; off >= 0 && off < n { |
||||
bs[off] += 1 |
||||
} |
||||
r.total += n |
||||
return n, err |
||||
} |
||||
|
||||
func TestTampering(t *testing.T) { |
||||
// Tamper with every byte of the client initiation message.
|
||||
for i := 0; i < 96; i++ { |
||||
var ( |
||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000) |
||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}} |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
_, err := Server(context.Background(), serverConn, serverKey) |
||||
// If the server failed, we have to close the Conn to
|
||||
// unblock the client.
|
||||
if err != nil { |
||||
serverConn.Close() |
||||
} |
||||
serverErr <- err |
||||
}() |
||||
|
||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err == nil { |
||||
t.Fatal("client connection succeeded despite tampering") |
||||
} |
||||
if err := <-serverErr; err == nil { |
||||
t.Fatalf("server connection succeeded despite tampering") |
||||
} |
||||
} |
||||
|
||||
// Tamper with every byte of the server response message.
|
||||
for i := 0; i < 48; i++ { |
||||
var ( |
||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000) |
||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}} |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
_, err := Server(context.Background(), serverConn, serverKey) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err == nil { |
||||
t.Fatal("client connection succeeded despite tampering") |
||||
} |
||||
// The server shouldn't fail, because the tampering took place
|
||||
// in its response.
|
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server connection failed despite no tampering: %v", err) |
||||
} |
||||
} |
||||
|
||||
// Tamper with every byte of the first server>client transport message.
|
||||
for i := 0; i < 32; i++ { |
||||
var ( |
||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000) |
||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}} |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
server, err := Server(context.Background(), serverConn, serverKey) |
||||
serverErr <- err |
||||
_, err = io.WriteString(server, strings.Repeat("a", 14)) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err != nil { |
||||
t.Fatalf("client handshake failed: %v", err) |
||||
} |
||||
// The server shouldn't fail, because the tampering took place
|
||||
// in its response.
|
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server handshake failed: %v", err) |
||||
} |
||||
|
||||
// The client needs a timeout if the tampering is hitting the length header.
|
||||
if i == 0 || i == 1 { |
||||
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) |
||||
} |
||||
|
||||
var bs [100]byte |
||||
n, err := client.Read(bs[:]) |
||||
if err == nil { |
||||
t.Fatal("read succeeded despite tampering") |
||||
} |
||||
if n != 0 { |
||||
t.Fatal("conn yielded some bytes despite tampering") |
||||
} |
||||
} |
||||
|
||||
// Tamper with every byte of the first client>server transport message.
|
||||
for i := 0; i < 32; i++ { |
||||
var ( |
||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000) |
||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}} |
||||
serverKey = key.NewPrivate() |
||||
clientKey = key.NewPrivate() |
||||
serverErr = make(chan error, 1) |
||||
) |
||||
go func() { |
||||
server, err := Server(context.Background(), serverConn, serverKey) |
||||
serverErr <- err |
||||
var bs [100]byte |
||||
// The server needs a timeout if the tampering is hitting the length header.
|
||||
if i == 0 || i == 1 { |
||||
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) |
||||
} |
||||
n, err := server.Read(bs[:]) |
||||
if n != 0 { |
||||
panic("server got bytes despite tampering") |
||||
} else { |
||||
serverErr <- err |
||||
} |
||||
}() |
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) |
||||
if err != nil { |
||||
t.Fatalf("client handshake failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server handshake failed: %v", err) |
||||
} |
||||
|
||||
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil { |
||||
t.Fatalf("client>server write failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err == nil { |
||||
t.Fatal("server successfully received bytes despite tampering") |
||||
} |
||||
} |
||||
} |
||||
@ -0,0 +1,238 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/binary" |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
"testing" |
||||
|
||||
tsnettest "tailscale.com/net/nettest" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
// Can a reference Noise IK client talk to our server?
|
||||
func TestInteropClient(t *testing.T) { |
||||
var ( |
||||
s1, s2 = tsnettest.NewConn("noise", 128000) |
||||
controlKey = key.NewPrivate() |
||||
machineKey = key.NewPrivate() |
||||
serverErr = make(chan error, 2) |
||||
serverBytes = make(chan []byte, 1) |
||||
c2s = "client>server" |
||||
s2c = "server>client" |
||||
) |
||||
|
||||
go func() { |
||||
server, err := Server(context.Background(), s2, controlKey) |
||||
serverErr <- err |
||||
if err != nil { |
||||
return |
||||
} |
||||
var buf [1024]byte |
||||
_, err = io.ReadFull(server, buf[:len(c2s)]) |
||||
serverBytes <- buf[:len(c2s)] |
||||
if err != nil { |
||||
serverErr <- err |
||||
return |
||||
} |
||||
_, err = server.Write([]byte(s2c)) |
||||
serverErr <- err |
||||
}() |
||||
|
||||
gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) |
||||
if err != nil { |
||||
t.Fatalf("failed client interop: %v", err) |
||||
} |
||||
if string(gotS2C) != s2c { |
||||
t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) |
||||
} |
||||
|
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server handshake failed: %v", err) |
||||
} |
||||
if err := <-serverErr; err != nil { |
||||
t.Fatalf("server read/write failed: %v", err) |
||||
} |
||||
if got := string(<-serverBytes); got != c2s { |
||||
t.Fatalf("server received %q, want %q", got, c2s) |
||||
} |
||||
} |
||||
|
||||
// Can our client talk to a reference Noise IK server?
|
||||
func TestInteropServer(t *testing.T) { |
||||
var ( |
||||
s1, s2 = tsnettest.NewConn("noise", 128000) |
||||
controlKey = key.NewPrivate() |
||||
machineKey = key.NewPrivate() |
||||
clientErr = make(chan error, 2) |
||||
clientBytes = make(chan []byte, 1) |
||||
c2s = "client>server" |
||||
s2c = "server>client" |
||||
) |
||||
|
||||
go func() { |
||||
client, err := Client(context.Background(), s1, machineKey, controlKey.Public()) |
||||
clientErr <- err |
||||
if err != nil { |
||||
return |
||||
} |
||||
_, err = client.Write([]byte(c2s)) |
||||
if err != nil { |
||||
clientErr <- err |
||||
return |
||||
} |
||||
var buf [1024]byte |
||||
_, err = io.ReadFull(client, buf[:len(s2c)]) |
||||
clientBytes <- buf[:len(s2c)] |
||||
clientErr <- err |
||||
}() |
||||
|
||||
gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) |
||||
if err != nil { |
||||
t.Fatalf("failed server interop: %v", err) |
||||
} |
||||
if string(gotC2S) != c2s { |
||||
t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) |
||||
} |
||||
|
||||
if err := <-clientErr; err != nil { |
||||
t.Fatalf("client handshake failed: %v", err) |
||||
} |
||||
if err := <-clientErr; err != nil { |
||||
t.Fatalf("client read/write failed: %v", err) |
||||
} |
||||
if got := string(<-clientBytes); got != s2c { |
||||
t.Fatalf("client received %q, want %q", got, s2c) |
||||
} |
||||
} |
||||
|
||||
// noiseExplorerClient uses the Noise Explorer implementation of Noise
|
||||
// IK to handshake as a Noise client on conn, transmit payload, and
|
||||
// read+return a payload from the peer.
|
||||
func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Private, payload []byte) ([]byte, error) { |
||||
mk := keypair{ |
||||
private_key: machineKey, |
||||
public_key: machineKey.Public(), |
||||
} |
||||
session := InitSession(true, nil, mk, controlKey) |
||||
|
||||
_, msg1 := SendMessage(&session, nil) |
||||
if _, err := conn.Write(msg1.ne[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := conn.Write(msg1.ns); err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := conn.Write(msg1.ciphertext); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var buf [1024]byte |
||||
if _, err := io.ReadFull(conn, buf[:48]); err != nil { |
||||
return nil, err |
||||
} |
||||
msg2 := messagebuffer{ |
||||
ciphertext: buf[32:48], |
||||
} |
||||
copy(msg2.ne[:], buf[:32]) |
||||
_, p, valid := RecvMessage(&session, &msg2) |
||||
if !valid { |
||||
return nil, errors.New("handshake failed") |
||||
} |
||||
if len(p) != 0 { |
||||
return nil, errors.New("non-empty payload") |
||||
} |
||||
|
||||
_, msg3 := SendMessage(&session, payload) |
||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext))) |
||||
if _, err := conn.Write(buf[:2]); err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := conn.Write(msg3.ciphertext); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil { |
||||
return nil, err |
||||
} |
||||
plen := int(binary.BigEndian.Uint16(buf[:2])) |
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
msg4 := messagebuffer{ |
||||
ciphertext: buf[:plen], |
||||
} |
||||
_, p, valid = RecvMessage(&session, &msg4) |
||||
if !valid { |
||||
return nil, errors.New("transport message decryption failed") |
||||
} |
||||
|
||||
return p, nil |
||||
} |
||||
|
||||
func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey key.Public, payload []byte) ([]byte, error) { |
||||
mk := keypair{ |
||||
private_key: controlKey, |
||||
public_key: controlKey.Public(), |
||||
} |
||||
session := InitSession(false, nil, mk, [32]byte{}) |
||||
|
||||
var buf [1024]byte |
||||
if _, err := io.ReadFull(conn, buf[:96]); err != nil { |
||||
return nil, err |
||||
} |
||||
msg1 := messagebuffer{ |
||||
ns: buf[32:80], |
||||
ciphertext: buf[80:96], |
||||
} |
||||
copy(msg1.ne[:], buf[:32]) |
||||
_, p, valid := RecvMessage(&session, &msg1) |
||||
if !valid { |
||||
return nil, errors.New("handshake failed") |
||||
} |
||||
if len(p) != 0 { |
||||
return nil, errors.New("non-empty payload") |
||||
} |
||||
|
||||
_, msg2 := SendMessage(&session, nil) |
||||
if _, err := conn.Write(msg2.ne[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := conn.Write(msg2.ciphertext[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil { |
||||
return nil, err |
||||
} |
||||
plen := int(binary.BigEndian.Uint16(buf[:2])) |
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
msg3 := messagebuffer{ |
||||
ciphertext: buf[:plen], |
||||
} |
||||
_, p, valid = RecvMessage(&session, &msg3) |
||||
if !valid { |
||||
return nil, errors.New("transport message decryption failed") |
||||
} |
||||
|
||||
_, msg4 := SendMessage(&session, payload) |
||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext))) |
||||
if _, err := conn.Write(buf[:2]); err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := conn.Write(msg4.ciphertext); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return p, nil |
||||
} |
||||
@ -0,0 +1,475 @@ |
||||
// This file contains the implementation of Noise IK from
|
||||
// https://noiseexplorer.com/ . Unlike the rest of this repository,
|
||||
// this file is licensed under the terms of the GNU GPL v3. See
|
||||
// https://source.symbolic.software/noiseexplorer/noiseexplorer for
|
||||
// more information.
|
||||
//
|
||||
// This file is used here to verify that Tailscale's implementation of
|
||||
// Noise IK is interoperable with another implementation.
|
||||
//lint:file-ignore SA4006 not our code.
|
||||
|
||||
/* |
||||
IK: |
||||
<- s |
||||
... |
||||
-> e, es, s, ss |
||||
<- e, ee, se |
||||
-> |
||||
<- |
||||
*/ |
||||
|
||||
// Implementation Version: 1.0.2
|
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* PARAMETERS * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
package noise |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"crypto/subtle" |
||||
"encoding/binary" |
||||
"hash" |
||||
"io" |
||||
"math" |
||||
|
||||
"golang.org/x/crypto/blake2s" |
||||
"golang.org/x/crypto/chacha20poly1305" |
||||
"golang.org/x/crypto/curve25519" |
||||
"golang.org/x/crypto/hkdf" |
||||
) |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* TYPES * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
type keypair struct { |
||||
public_key [32]byte |
||||
private_key [32]byte |
||||
} |
||||
|
||||
type messagebuffer struct { |
||||
ne [32]byte |
||||
ns []byte |
||||
ciphertext []byte |
||||
} |
||||
|
||||
type cipherstate struct { |
||||
k [32]byte |
||||
n uint32 |
||||
} |
||||
|
||||
type symmetricstate struct { |
||||
cs cipherstate |
||||
ck [32]byte |
||||
h [32]byte |
||||
} |
||||
|
||||
type handshakestate struct { |
||||
ss symmetricstate |
||||
s keypair |
||||
e keypair |
||||
rs [32]byte |
||||
re [32]byte |
||||
psk [32]byte |
||||
} |
||||
|
||||
type noisesession struct { |
||||
hs handshakestate |
||||
h [32]byte |
||||
cs1 cipherstate |
||||
cs2 cipherstate |
||||
mc uint64 |
||||
i bool |
||||
} |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* CONSTANTS * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
var emptyKey = [32]byte{ |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
0x00, 0x00, 0x00, 0x00, |
||||
} |
||||
|
||||
var minNonce = uint32(0) |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* UTILITY FUNCTIONS * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
func getPublicKey(kp *keypair) [32]byte { |
||||
return kp.public_key |
||||
} |
||||
|
||||
func isEmptyKey(k [32]byte) bool { |
||||
return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1 |
||||
} |
||||
|
||||
func validatePublicKey(k []byte) bool { |
||||
forbiddenCurveValues := [12][]byte{ |
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, |
||||
{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, |
||||
{224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0}, |
||||
{95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87}, |
||||
{236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, |
||||
{237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, |
||||
{238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, |
||||
{205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128}, |
||||
{76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215}, |
||||
{217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, |
||||
{218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, |
||||
{219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25}, |
||||
} |
||||
|
||||
for _, testValue := range forbiddenCurveValues { |
||||
if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 { |
||||
panic("Invalid public key") |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* PRIMITIVES * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
func incrementNonce(n uint32) uint32 { |
||||
return n + 1 |
||||
} |
||||
|
||||
func dh(private_key [32]byte, public_key [32]byte) [32]byte { |
||||
var ss [32]byte |
||||
curve25519.ScalarMult(&ss, &private_key, &public_key) |
||||
return ss |
||||
} |
||||
|
||||
func generateKeypair() keypair { |
||||
var public_key [32]byte |
||||
var private_key [32]byte |
||||
_, _ = rand.Read(private_key[:]) |
||||
curve25519.ScalarBaseMult(&public_key, &private_key) |
||||
if validatePublicKey(public_key[:]) { |
||||
return keypair{public_key, private_key} |
||||
} |
||||
return generateKeypair() |
||||
} |
||||
|
||||
func generatePublicKey(private_key [32]byte) [32]byte { |
||||
var public_key [32]byte |
||||
curve25519.ScalarBaseMult(&public_key, &private_key) |
||||
return public_key |
||||
} |
||||
|
||||
func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte { |
||||
var nonce [12]byte |
||||
var ciphertext []byte |
||||
enc, _ := chacha20poly1305.New(k[:]) |
||||
binary.LittleEndian.PutUint32(nonce[4:], n) |
||||
ciphertext = enc.Seal(nil, nonce[:], plaintext, ad) |
||||
return ciphertext |
||||
} |
||||
|
||||
func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) { |
||||
var nonce [12]byte |
||||
var plaintext []byte |
||||
enc, err := chacha20poly1305.New(k[:]) |
||||
binary.LittleEndian.PutUint32(nonce[4:], n) |
||||
plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad) |
||||
return (err == nil), ad, plaintext |
||||
} |
||||
|
||||
func getHash(a []byte, b []byte) [32]byte { |
||||
return blake2s.Sum256(append(a, b...)) |
||||
} |
||||
|
||||
func hashProtocolName(protocolName []byte) [32]byte { |
||||
var h [32]byte |
||||
if len(protocolName) <= 32 { |
||||
copy(h[:], protocolName) |
||||
} else { |
||||
h = getHash(protocolName, []byte{}) |
||||
} |
||||
return h |
||||
} |
||||
|
||||
func blake2HkdfInterface() hash.Hash { |
||||
h, _ := blake2s.New256([]byte{}) |
||||
return h |
||||
} |
||||
|
||||
func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) { |
||||
var k1 [32]byte |
||||
var k2 [32]byte |
||||
var k3 [32]byte |
||||
output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{}) |
||||
io.ReadFull(output, k1[:]) |
||||
io.ReadFull(output, k2[:]) |
||||
io.ReadFull(output, k3[:]) |
||||
return k1, k2, k3 |
||||
} |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* STATE MANAGEMENT * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
/* CipherState */ |
||||
func initializeKey(k [32]byte) cipherstate { |
||||
return cipherstate{k, minNonce} |
||||
} |
||||
|
||||
func hasKey(cs *cipherstate) bool { |
||||
return !isEmptyKey(cs.k) |
||||
} |
||||
|
||||
func setNonce(cs *cipherstate, newNonce uint32) *cipherstate { |
||||
cs.n = newNonce |
||||
return cs |
||||
} |
||||
|
||||
func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) { |
||||
e := encrypt(cs.k, cs.n, ad, plaintext) |
||||
cs = setNonce(cs, incrementNonce(cs.n)) |
||||
return cs, e |
||||
} |
||||
|
||||
func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) { |
||||
valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext) |
||||
cs = setNonce(cs, incrementNonce(cs.n)) |
||||
return cs, plaintext, valid |
||||
} |
||||
|
||||
func reKey(cs *cipherstate) *cipherstate { |
||||
e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:]) |
||||
copy(cs.k[:], e) |
||||
return cs |
||||
} |
||||
|
||||
/* SymmetricState */ |
||||
|
||||
func initializeSymmetric(protocolName []byte) symmetricstate { |
||||
h := hashProtocolName(protocolName) |
||||
ck := h |
||||
cs := initializeKey(emptyKey) |
||||
return symmetricstate{cs, ck, h} |
||||
} |
||||
|
||||
func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate { |
||||
ck, tempK, _ := getHkdf(ss.ck, ikm[:]) |
||||
ss.cs = initializeKey(tempK) |
||||
ss.ck = ck |
||||
return ss |
||||
} |
||||
|
||||
func mixHash(ss *symmetricstate, data []byte) *symmetricstate { |
||||
ss.h = getHash(ss.h[:], data) |
||||
return ss |
||||
} |
||||
|
||||
func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate { |
||||
var tempH [32]byte |
||||
var tempK [32]byte |
||||
ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:]) |
||||
ss = mixHash(ss, tempH[:]) |
||||
ss.cs = initializeKey(tempK) |
||||
return ss |
||||
} |
||||
|
||||
func getHandshakeHash(ss *symmetricstate) [32]byte { |
||||
return ss.h |
||||
} |
||||
|
||||
func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) { |
||||
var ciphertext []byte |
||||
if hasKey(&ss.cs) { |
||||
_, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext) |
||||
} else { |
||||
ciphertext = plaintext |
||||
} |
||||
ss = mixHash(ss, ciphertext) |
||||
return ss, ciphertext |
||||
} |
||||
|
||||
func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) { |
||||
var plaintext []byte |
||||
var valid bool |
||||
if hasKey(&ss.cs) { |
||||
_, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext) |
||||
} else { |
||||
plaintext, valid = ciphertext, true |
||||
} |
||||
ss = mixHash(ss, ciphertext) |
||||
return ss, plaintext, valid |
||||
} |
||||
|
||||
func split(ss *symmetricstate) (cipherstate, cipherstate) { |
||||
tempK1, tempK2, _ := getHkdf(ss.ck, []byte{}) |
||||
cs1 := initializeKey(tempK1) |
||||
cs2 := initializeKey(tempK2) |
||||
return cs1, cs2 |
||||
} |
||||
|
||||
/* HandshakeState */ |
||||
|
||||
func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate { |
||||
var ss symmetricstate |
||||
var e keypair |
||||
var re [32]byte |
||||
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s") |
||||
ss = initializeSymmetric(name) |
||||
mixHash(&ss, prologue) |
||||
mixHash(&ss, rs[:]) |
||||
return handshakestate{ss, s, e, rs, re, psk} |
||||
} |
||||
|
||||
func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate { |
||||
var ss symmetricstate |
||||
var e keypair |
||||
var re [32]byte |
||||
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s") |
||||
ss = initializeSymmetric(name) |
||||
mixHash(&ss, prologue) |
||||
mixHash(&ss, s.public_key[:]) |
||||
return handshakestate{ss, s, e, rs, re, psk} |
||||
} |
||||
|
||||
func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) { |
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{} |
||||
hs.e = generateKeypair() |
||||
ne = hs.e.public_key |
||||
mixHash(&hs.ss, ne[:]) |
||||
/* No PSK, so skipping mixKey */ |
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) |
||||
spk := make([]byte, len(hs.s.public_key)) |
||||
copy(spk[:], hs.s.public_key[:]) |
||||
_, ns = encryptAndHash(&hs.ss, spk) |
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) |
||||
_, ciphertext = encryptAndHash(&hs.ss, payload) |
||||
messageBuffer := messagebuffer{ne, ns, ciphertext} |
||||
return hs, messageBuffer |
||||
} |
||||
|
||||
func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) { |
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{} |
||||
hs.e = generateKeypair() |
||||
ne = hs.e.public_key |
||||
mixHash(&hs.ss, ne[:]) |
||||
/* No PSK, so skipping mixKey */ |
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) |
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) |
||||
_, ciphertext = encryptAndHash(&hs.ss, payload) |
||||
messageBuffer := messagebuffer{ne, ns, ciphertext} |
||||
cs1, cs2 := split(&hs.ss) |
||||
return hs.ss.h, messageBuffer, cs1, cs2 |
||||
} |
||||
|
||||
func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) { |
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{} |
||||
cs, ciphertext = encryptWithAd(cs, []byte{}, payload) |
||||
messageBuffer := messagebuffer{ne, ns, ciphertext} |
||||
return cs, messageBuffer |
||||
} |
||||
|
||||
func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) { |
||||
valid1 := true |
||||
if validatePublicKey(message.ne[:]) { |
||||
hs.re = message.ne |
||||
} |
||||
mixHash(&hs.ss, hs.re[:]) |
||||
/* No PSK, so skipping mixKey */ |
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) |
||||
_, ns, valid1 := decryptAndHash(&hs.ss, message.ns) |
||||
if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) { |
||||
copy(hs.rs[:], ns) |
||||
} |
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) |
||||
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) |
||||
return hs, plaintext, (valid1 && valid2) |
||||
} |
||||
|
||||
func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) { |
||||
valid1 := true |
||||
if validatePublicKey(message.ne[:]) { |
||||
hs.re = message.ne |
||||
} |
||||
mixHash(&hs.ss, hs.re[:]) |
||||
/* No PSK, so skipping mixKey */ |
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) |
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) |
||||
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) |
||||
cs1, cs2 := split(&hs.ss) |
||||
return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2 |
||||
} |
||||
|
||||
func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) { |
||||
/* No encrypted keys */ |
||||
_, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext) |
||||
return cs, plaintext, valid2 |
||||
} |
||||
|
||||
/* ---------------------------------------------------------------- * |
||||
* PROCESSES * |
||||
* ---------------------------------------------------------------- */ |
||||
|
||||
func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession { |
||||
var session noisesession |
||||
psk := emptyKey |
||||
if initiator { |
||||
session.hs = initializeInitiator(prologue, s, rs, psk) |
||||
} else { |
||||
session.hs = initializeResponder(prologue, s, rs, psk) |
||||
} |
||||
session.i = initiator |
||||
session.mc = 0 |
||||
return session |
||||
} |
||||
|
||||
func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) { |
||||
var messageBuffer messagebuffer |
||||
if session.mc == 0 { |
||||
_, messageBuffer = writeMessageA(&session.hs, message) |
||||
} |
||||
if session.mc == 1 { |
||||
session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message) |
||||
session.hs = handshakestate{} |
||||
} |
||||
if session.mc > 1 { |
||||
if session.i { |
||||
_, messageBuffer = writeMessageRegular(&session.cs1, message) |
||||
} else { |
||||
_, messageBuffer = writeMessageRegular(&session.cs2, message) |
||||
} |
||||
} |
||||
session.mc = session.mc + 1 |
||||
return session, messageBuffer |
||||
} |
||||
|
||||
func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) { |
||||
var plaintext []byte |
||||
var valid bool |
||||
if session.mc == 0 { |
||||
_, plaintext, valid = readMessageA(&session.hs, message) |
||||
} |
||||
if session.mc == 1 { |
||||
session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message) |
||||
session.hs = handshakestate{} |
||||
} |
||||
if session.mc > 1 { |
||||
if session.i { |
||||
_, plaintext, valid = readMessageRegular(&session.cs2, message) |
||||
} else { |
||||
_, plaintext, valid = readMessageRegular(&session.cs1, message) |
||||
} |
||||
} |
||||
session.mc = session.mc + 1 |
||||
return session, plaintext, valid |
||||
} |
||||
|
||||
func main() {} |
||||
Loading…
Reference in new issue