You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
4.3 KiB
149 lines
4.3 KiB
// Copyright (c) Tailscale Inc & contributors
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package taildrop
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
var (
	// blockSize is the maximum number of bytes covered by a single block
	// checksum; partial files are hashed in chunks of this size.
	blockSize int64 = 64 << 10

	// hashAlgorithm is the algorithm name recorded in every blockChecksum
	// and the only one resumeReader accepts.
	hashAlgorithm = "sha256"
)
|
|
|
|
// blockChecksum represents the checksum for a single block.
// Blocks are consecutive chunks of a file of up to blockSize bytes each;
// the JSON field names are given by the struct tags below.
type blockChecksum struct {
	Checksum checksum `json:"checksum"` // hash of the block's contents
	Algorithm string `json:"algo"` // always "sha256" (hashAlgorithm) for now
	Size int64 `json:"size"` // number of bytes in the block: at most blockSize (64<<10); the final block of a file may be shorter
}
|
|
|
|
// checksum is an opaque checksum that is comparable.
|
|
type checksum struct{ cs [sha256.Size]byte }
|
|
|
|
func hash(b []byte) checksum {
|
|
return checksum{sha256.Sum256(b)}
|
|
}
|
|
func (cs checksum) String() string {
|
|
return hex.EncodeToString(cs.cs[:])
|
|
}
|
|
func (cs checksum) AppendText(b []byte) ([]byte, error) {
|
|
return hex.AppendEncode(b, cs.cs[:]), nil
|
|
}
|
|
func (cs checksum) MarshalText() ([]byte, error) {
|
|
return hex.AppendEncode(nil, cs.cs[:]), nil
|
|
}
|
|
func (cs *checksum) UnmarshalText(b []byte) error {
|
|
if len(b) != 2*len(cs.cs) {
|
|
return fmt.Errorf("invalid hex length: %d", len(b))
|
|
}
|
|
_, err := hex.Decode(cs.cs[:], b)
|
|
return err
|
|
}
|
|
|
|
// PartialFiles returns a list of partial files in [Handler.Dir]
|
|
// that were sent (or is actively being sent) by the provided id.
|
|
func (m *manager) PartialFiles(id clientID) ([]string, error) {
|
|
if m == nil || m.opts.fileOps == nil {
|
|
return nil, ErrNoTaildrop
|
|
}
|
|
suffix := id.partialSuffix()
|
|
files, err := m.opts.fileOps.ListFiles()
|
|
if err != nil {
|
|
return nil, redactError(err)
|
|
}
|
|
var ret []string
|
|
for _, filename := range files {
|
|
if strings.HasSuffix(filename, suffix) {
|
|
ret = append(ret, filename)
|
|
}
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// HashPartialFile returns a function that hashes the next block in the file,
|
|
// starting from the beginning of the file.
|
|
// It returns (BlockChecksum{}, io.EOF) when the stream is complete.
|
|
// It is the caller's responsibility to call close.
|
|
func (m *manager) HashPartialFile(id clientID, baseName string) (next func() (blockChecksum, error), close func() error, err error) {
|
|
if m == nil || m.opts.fileOps == nil {
|
|
return nil, nil, ErrNoTaildrop
|
|
}
|
|
noopNext := func() (blockChecksum, error) { return blockChecksum{}, io.EOF }
|
|
noopClose := func() error { return nil }
|
|
|
|
f, err := m.opts.fileOps.OpenReader(baseName + id.partialSuffix())
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return noopNext, noopClose, nil
|
|
}
|
|
return nil, nil, redactError(err)
|
|
}
|
|
|
|
b := make([]byte, blockSize) // TODO: Pool this?
|
|
next = func() (blockChecksum, error) {
|
|
switch n, err := io.ReadFull(f, b); {
|
|
case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF:
|
|
return blockChecksum{}, redactError(err)
|
|
case n == 0:
|
|
return blockChecksum{}, io.EOF
|
|
default:
|
|
return blockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil
|
|
}
|
|
}
|
|
close = f.Close
|
|
return next, close, nil
|
|
}
|
|
|
|
// resumeReader reads and discards the leading content of r
|
|
// that matches the content based on the checksums that exist.
|
|
// It returns the number of bytes consumed,
|
|
// and returns an [io.Reader] representing the remaining content.
|
|
func resumeReader(r io.Reader, hashNext func() (blockChecksum, error)) (int64, io.Reader, error) {
|
|
if hashNext == nil {
|
|
return 0, r, nil
|
|
}
|
|
|
|
var offset int64
|
|
b := make([]byte, 0, blockSize)
|
|
for {
|
|
// Obtain the next block checksum from the remote peer.
|
|
cs, err := hashNext()
|
|
switch {
|
|
case err == io.EOF:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), nil
|
|
case err != nil:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), err
|
|
case cs.Algorithm != hashAlgorithm || cs.Size < 0 || cs.Size > blockSize:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm")
|
|
}
|
|
|
|
// Read the contents of the next block.
|
|
n, err := io.ReadFull(r, b[:cs.Size])
|
|
b = b[:n]
|
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
|
err = nil
|
|
}
|
|
if len(b) == 0 || err != nil {
|
|
// This should not occur in practice.
|
|
// It implies that an error occurred reading r,
|
|
// or that the partial file on the remote side is fully complete.
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), err
|
|
}
|
|
|
|
// Compare the local and remote block checksums.
|
|
// If it mismatches, then resume from this point.
|
|
if cs.Checksum != hash(b) {
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), nil
|
|
}
|
|
offset += int64(len(b))
|
|
b = b[:0]
|
|
}
|
|
}
|
|
|