This will be used in a future change to do localhost connection authentication. This lets us quickly map a localhost TCP connection to a PID. (A future change will then map a pid to a user) TODO: pull portlist's netstat code into this package. Then portlist will be fast on Windows without requiring shelling out to netstat.exe.main
parent
8b60936913
commit
f65eb4e5c1
@ -0,0 +1,36 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package netstat returns the local machine's network connection table.
|
||||
package netstat |
||||
|
||||
import ( |
||||
"errors" |
||||
"runtime" |
||||
|
||||
"inet.af/netaddr" |
||||
) |
||||
|
||||
var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) |
||||
|
||||
type Entry struct { |
||||
Local, Remote netaddr.IPPort |
||||
Pid int |
||||
State string // TODO: type?
|
||||
} |
||||
|
||||
// Table contains local machine's TCP connection entries.
|
||||
//
|
||||
// Currently only TCP (IPv4 and IPv6) are included.
|
||||
type Table struct { |
||||
Entries []Entry |
||||
} |
||||
|
||||
// Get returns the connection table.
|
||||
//
|
||||
// It returns ErrNotImplemented if the table is not available for the
|
||||
// current operating system.
|
||||
func Get() (*Table, error) { |
||||
return get() |
||||
} |
||||
@ -0,0 +1,11 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !windows
|
||||
|
||||
package netstat |
||||
|
||||
func get() (*Table, error) { |
||||
return nil, ErrNotImplemented |
||||
} |
||||
@ -0,0 +1,22 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package netstat |
||||
|
||||
import ( |
||||
"testing" |
||||
) |
||||
|
||||
func TestGet(t *testing.T) { |
||||
nt, err := Get() |
||||
if err == ErrNotImplemented { |
||||
t.Skip("TODO: not implemented") |
||||
} |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
for _, e := range nt.Entries { |
||||
t.Logf("Entry: %+v", e) |
||||
} |
||||
} |
||||
@ -0,0 +1,178 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package netstat returns the local machine's network connection table.
|
||||
package netstat |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"syscall" |
||||
"unsafe" |
||||
|
||||
"golang.org/x/sys/windows" |
||||
"inet.af/netaddr" |
||||
) |
||||
|
||||
// See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
|
||||
|
||||
// TCP_TABLE_OWNER_PID_ALL means to include the PID info. The table type
|
||||
// we get back from Windows depends on AF_INET vs AF_INET6:
|
||||
// MIB_TCPTABLE_OWNER_PID for v4 or MIB_TCP6TABLE_OWNER_PID for v6.
|
||||
const tcpTableOwnerPidAll = 5 |
||||
|
||||
var ( |
||||
iphlpapi = syscall.NewLazyDLL("iphlpapi.dll") |
||||
getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable") |
||||
// TODO: GetExtendedUdpTable also? if/when needed.
|
||||
) |
||||
|
||||
type _MIB_TCPROW_OWNER_PID struct { |
||||
state uint32 |
||||
localAddr uint32 |
||||
localPort uint32 |
||||
remoteAddr uint32 |
||||
remotePort uint32 |
||||
pid uint32 |
||||
} |
||||
|
||||
type _MIB_TCP6ROW_OWNER_PID struct { |
||||
localAddr [16]byte |
||||
localScope uint32 |
||||
localPort uint32 |
||||
remoteAddr [16]byte |
||||
remoteScope uint32 |
||||
remotePort uint32 |
||||
state uint32 |
||||
pid uint32 |
||||
} |
||||
|
||||
func get() (*Table, error) { |
||||
t := new(Table) |
||||
if err := t.addEntries(windows.AF_INET); err != nil { |
||||
return nil, fmt.Errorf("failed to get IPv4 entries: %w", err) |
||||
} |
||||
if err := t.addEntries(windows.AF_INET6); err != nil { |
||||
return nil, fmt.Errorf("failed to get IPv6 entries: %w", err) |
||||
} |
||||
return t, nil |
||||
} |
||||
|
||||
func (t *Table) addEntries(fam int) error { |
||||
var size uint32 |
||||
var addr unsafe.Pointer |
||||
var buf []byte |
||||
for { |
||||
err, _, _ := getTCPTable.Call( |
||||
uintptr(addr), |
||||
uintptr(unsafe.Pointer(&size)), |
||||
1, // sorted
|
||||
uintptr(fam), |
||||
tcpTableOwnerPidAll, |
||||
0, // reserved; "must be zero"
|
||||
) |
||||
if err == 0 { |
||||
break |
||||
} |
||||
if err == uintptr(syscall.ERROR_INSUFFICIENT_BUFFER) { |
||||
const maxSize = 10 << 20 |
||||
if size > maxSize || size < 4 { |
||||
return fmt.Errorf("unreasonable kernel-reported size %d", size) |
||||
} |
||||
buf = make([]byte, size) |
||||
addr = unsafe.Pointer(&buf[0]) |
||||
continue |
||||
} |
||||
return syscall.Errno(err) |
||||
} |
||||
if len(buf) < int(size) { |
||||
return errors.New("unexpected size growth from system call") |
||||
} |
||||
buf = buf[:size] |
||||
|
||||
numEntries := *(*uint32)(unsafe.Pointer(&buf[0])) |
||||
buf = buf[4:] |
||||
|
||||
var recSize int |
||||
switch fam { |
||||
case windows.AF_INET: |
||||
recSize = 6 * 4 |
||||
case windows.AF_INET6: |
||||
recSize = 6*4 + 16*2 |
||||
} |
||||
dataLen := numEntries * uint32(recSize) |
||||
if uint32(len(buf)) > dataLen { |
||||
buf = buf[:dataLen] |
||||
} |
||||
for len(buf) >= recSize { |
||||
switch fam { |
||||
case windows.AF_INET: |
||||
row := (*_MIB_TCPROW_OWNER_PID)(unsafe.Pointer(&buf[0])) |
||||
t.Entries = append(t.Entries, Entry{ |
||||
Local: ipport4(row.localAddr, port(&row.localPort)), |
||||
Remote: ipport4(row.remoteAddr, port(&row.remotePort)), |
||||
Pid: int(row.pid), |
||||
State: state(row.state), |
||||
}) |
||||
case windows.AF_INET6: |
||||
row := (*_MIB_TCP6ROW_OWNER_PID)(unsafe.Pointer(&buf[0])) |
||||
t.Entries = append(t.Entries, Entry{ |
||||
Local: ipport6(row.localAddr, row.localScope, port(&row.localPort)), |
||||
Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)), |
||||
Pid: int(row.pid), |
||||
State: state(row.state), |
||||
}) |
||||
} |
||||
buf = buf[recSize:] |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
var states = []string{ |
||||
"", |
||||
"CLOSED", |
||||
"LISTEN", |
||||
"SYN-SENT", |
||||
"SYN-RECEIVED", |
||||
"ESTABLISHED", |
||||
"FIN-WAIT-1", |
||||
"FIN-WAIT-2", |
||||
"CLOSE-WAIT", |
||||
"CLOSING", |
||||
"LAST-ACK", |
||||
"DELETE-TCB", |
||||
} |
||||
|
||||
func state(v uint32) string { |
||||
if v < uint32(len(states)) { |
||||
return states[v] |
||||
} |
||||
return fmt.Sprintf("unknown-state-%d", v) |
||||
} |
||||
|
||||
func ipport4(addr uint32, port uint16) netaddr.IPPort { |
||||
a4 := (*[4]byte)(unsafe.Pointer(&addr)) |
||||
return netaddr.IPPort{ |
||||
IP: netaddr.IPv4(a4[0], a4[1], a4[2], a4[3]), |
||||
Port: port, |
||||
} |
||||
} |
||||
|
||||
func ipport6(addr [16]byte, scope uint32, port uint16) netaddr.IPPort { |
||||
ip := netaddr.IPFrom16(addr) |
||||
if scope != 0 { |
||||
// TODO: something better here?
|
||||
ip = ip.WithZone(fmt.Sprint(scope)) |
||||
} |
||||
return netaddr.IPPort{ |
||||
IP: ip, |
||||
Port: port, |
||||
} |
||||
} |
||||
|
||||
func port(v *uint32) uint16 { |
||||
p := (*[4]byte)(unsafe.Pointer(v)) |
||||
return binary.BigEndian.Uint16(p[:2]) |
||||
} |
||||
Loading…
Reference in new issue