Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"io"
"net"
"syscall"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
Expand Down Expand Up @@ -59,14 +58,10 @@ func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBuf
}

func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64, batchSize int) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
var handled bool
handled, n, err = copyDirect(source, destination, readCounters, writeCounters)
if handled {
return
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter, batchSize)
}
Expand Down
13 changes: 6 additions & 7 deletions common/bufio/copy_direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@ package bufio
import (
"errors"
"io"
"syscall"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
func copyDirect(source io.Reader, destination io.Writer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) {
return
}
rawDestination, err := destination.SyscallConn()
if err != nil {
sourceReader, sourceConn := N.SyscallConnForRead(source)
destinationWriter, destinationConn := N.SyscallConnForWrite(destination)
if sourceConn == nil || destinationConn == nil {
return
}
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
handed, n, err = splice(sourceConn, sourceReader, destinationConn, destinationWriter, readCounters, writeCounters)
return
}

Expand Down
47 changes: 32 additions & 15 deletions common/bufio/splice_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

const maxSpliceSize = 1 << 20

func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
handed = true
var pipeFDs [2]int
err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK)
Expand All @@ -20,12 +20,14 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []
}
defer unix.Close(pipeFDs[0])
defer unix.Close(pipeFDs[1])

_, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize)
var readN int
var readErr error
var writeSize int
var writeErr error
var (
readN int
readErr error
writeSize int
writeErr error
notFirstTime bool
)
readFunc := func(fd uintptr) (done bool) {
p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK)
readN = int(p0)
Expand All @@ -46,34 +48,49 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []
}
for {
err = source.Read(readFunc)
if err != nil {
readErr = err
}
if readErr != nil {
if readErr == unix.EINVAL || readErr == unix.ENOSYS {
err = readErr
}
if err != nil {
if sourceReader != nil {
newBuffer, newErr := sourceReader.HandleSyscallReadError(err)
if newErr != nil {
err = newErr
} else {
err = nil
if len(newBuffer) > 0 {
readN, readErr = unix.Write(pipeFDs[1], newBuffer)
if readErr != nil {
err = E.Cause(err, "write handled data")
}
}
}
} else if !notFirstTime && E.IsMulti(err, unix.EINVAL, unix.ENOSYS) {
handed = false
return
}
err = E.Cause(readErr, "splice read")
err = E.Cause(err, "splice read")
return
}
if readN == 0 {
return
}
writeSize = readN
err = destination.Write(writeFunc)
if err != nil {
writeErr = err
}
if writeErr != nil {
err = E.Cause(writeErr, "splice write")
err = writeErr
}
if err != nil {
err = E.Cause(err, "splice write")
return
}
n += int64(readN)
for _, readCounter := range readCounters {
readCounter(int64(readN))
}
for _, writeCounter := range writeCounters {
writeCounter(int64(readN))
}
notFirstTime = true
}
}
2 changes: 1 addition & 1 deletion common/bufio/splice_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ import (
N "github.com/sagernet/sing/common/network"
)

func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
return
}
31 changes: 28 additions & 3 deletions common/bufio/wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package bufio
import (
"io"

"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network"
)

func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) {
reader = N.UnwrapReader(reader)
if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter {
return readWaiter, true
}
Expand All @@ -17,11 +17,19 @@ func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) {
if readWaiter, created := createSyscallReadWaiter(reader); created {
return readWaiter, true
}
if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return nil, false
}
if u, ok := reader.(N.WithUpstreamReader); ok {
return CreateReadWaiter(u.UpstreamReader().(io.Reader))
}
if u, ok := reader.(common.WithUpstream); ok {
return CreateReadWaiter(u.Upstream().(io.Reader))
}
return nil, false
}

func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool) {
reader = N.UnwrapReader(reader)
if vectorisedReadWaiter, isVectorised := reader.(N.VectorisedReadWaiter); isVectorised {
return vectorisedReadWaiter, true
}
Expand All @@ -31,11 +39,19 @@ func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool)
if vectorisedReadWaiter, created := createVectorisedSyscallReadWaiter(reader); created {
return vectorisedReadWaiter, true
}
if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return nil, false
}
if u, ok := reader.(N.WithUpstreamReader); ok {
return CreateVectorisedReadWaiter(u.UpstreamReader().(io.Reader))
}
if u, ok := reader.(common.WithUpstream); ok {
return CreateVectorisedReadWaiter(u.Upstream().(io.Reader))
}
return nil, false
}

func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) {
reader = N.UnwrapPacketReader(reader)
if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter {
return readWaiter, true
}
Expand All @@ -45,6 +61,15 @@ func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) {
if readWaiter, created := createSyscallPacketReadWaiter(reader); created {
return readWaiter, true
}
if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return nil, false
}
if u, ok := reader.(N.WithUpstreamReader); ok {
return CreatePacketReadWaiter(u.UpstreamReader().(N.PacketReader))
}
if u, ok := reader.(common.WithUpstream); ok {
return CreatePacketReadWaiter(u.Upstream().(N.PacketReader))
}
return nil, false
}

Expand Down
70 changes: 67 additions & 3 deletions common/metadata/domain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,70 @@
package metadata

import _ "unsafe" // for linkname
// IsDomainName checks if a string is a presentation-format domain name
// (currently restricted to hostname-compatible "preferred name" LDH labels and
// SRV-like "underscore labels"; see golang.org/issue/12421).
//
// This function was originally created here:
//
// https://cs.opensource.google/go/go/+/master:src/net/dnsclient.go;l=76-146;drc=05cbbf985fed823a174bf95cc78a7d44f948fdab
//
// and it's being copy-pasted in order to use the same functionality. In the original package,
// this is a private function that cannot be accessed externally
func IsDomainName(s string) bool {
// The root domain name is valid. See golang.org/issue/45715.
if s == "." {
return true
}

//go:linkname IsDomainName net.isDomainName
func IsDomainName(domain string) bool
// See RFC 1035, RFC 3696.
// Presentation format has dots before every label except the first, and the
// terminal empty label is optional here because we assume fully-qualified
// (absolute) input. We must therefore reserve space for the first and last
// labels' length octets in wire format, where they are necessary and the
// maximum total length is 255.
// So our _effective_ maximum is 253, but 254 is not rejected if the last
// character is a dot.
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}

last := byte('.')
nonNumeric := false // true once we've seen a letter or hyphen
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
// fine
partlen++
case c == '-':
// Byte before dash cannot be dot.
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
// Byte before dot cannot be dot, dash.
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}

return nonNumeric
}
54 changes: 51 additions & 3 deletions common/network/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package network

import (
"io"
"syscall"

"github.com/sagernet/sing/common"
)

type CountFunc func(n int64)
Expand All @@ -27,32 +30,65 @@ type PacketWriteCounter interface {
}

func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []CountFunc) {
reader = UnwrapReader(reader)
if counter, isCounter := reader.(ReadCounter); isCounter {
upstreamReader, upstreamCountFunc := counter.UnwrapReader()
countFunc = append(countFunc, upstreamCountFunc...)
return UnwrapCountReader(upstreamReader, countFunc)
}
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return reader, countFunc
}
switch u := reader.(type) {
case ReadWaiter, ReadWaitCreator, syscall.Conn, SyscallReader:
// In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter
return reader, countFunc
case WithUpstreamReader:
return UnwrapCountReader(u.UpstreamReader().(io.Reader), countFunc)
case common.WithUpstream:
return UnwrapCountReader(u.Upstream().(io.Reader), countFunc)
}
return reader, countFunc
}

func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []CountFunc) {
writer = UnwrapWriter(writer)
if counter, isCounter := writer.(WriteCounter); isCounter {
upstreamWriter, upstreamCountFunc := counter.UnwrapWriter()
countFunc = append(countFunc, upstreamCountFunc...)
return UnwrapCountWriter(upstreamWriter, countFunc)
}
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer, countFunc
}
switch u := writer.(type) {
case syscall.Conn, SyscallWriter:
// In our use cases, counters is always at the top, so we stop when we encounter syscall conn
return writer, countFunc
case WithUpstreamWriter:
return UnwrapCountWriter(u.UpstreamWriter().(io.Writer), countFunc)
case common.WithUpstream:
return UnwrapCountWriter(u.Upstream().(io.Writer), countFunc)
}
return writer, countFunc
}

func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (PacketReader, []CountFunc) {
reader = UnwrapPacketReader(reader)
if counter, isCounter := reader.(PacketReadCounter); isCounter {
upstreamReader, upstreamCountFunc := counter.UnwrapPacketReader()
countFunc = append(countFunc, upstreamCountFunc...)
return UnwrapCountPacketReader(upstreamReader, countFunc)
}
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return reader, countFunc
}
switch u := reader.(type) {
case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn:
// In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter
return reader, countFunc
case WithUpstreamReader:
return UnwrapCountPacketReader(u.UpstreamReader().(PacketReader), countFunc)
case common.WithUpstream:
return UnwrapCountPacketReader(u.Upstream().(PacketReader), countFunc)
}
return reader, countFunc
}

Expand All @@ -63,5 +99,17 @@ func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (Packet
countFunc = append(countFunc, upstreamCountFunc...)
return UnwrapCountPacketWriter(upstreamWriter, countFunc)
}
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer, countFunc
}
switch u := writer.(type) {
case syscall.Conn:
// In our use cases, counters is always at the top, so we stop when we encounter syscall conn
return writer, countFunc
case WithUpstreamWriter:
return UnwrapCountPacketWriter(u.UpstreamWriter().(PacketWriter), countFunc)
case common.WithUpstream:
return UnwrapCountPacketWriter(u.Upstream().(PacketWriter), countFunc)
}
return writer, countFunc
}
Loading