Skip to content

Commit 3b47d08

Browse files
authored
Freedom: Cache UDP resolve result (#4804)
1 parent 7f23a1c commit 3b47d08

File tree

2 files changed

+157
-5
lines changed

2 files changed

+157
-5
lines changed

common/utils/typed_sync_map.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package utils
2+
3+
import (
4+
"sync"
5+
)
6+
7+
// TypedSyncMap is a wrapper of sync.Map that provides type-safe for keys and values.
8+
// No need to use type assertions every time, so you can have more time to enjoy other things like GochiUsa
9+
// If sync.Map methods returned nil, it will return the zero value of the type V.
10+
type TypedSyncMap[K, V any] struct {
11+
syncMap *sync.Map
12+
}
13+
14+
// NewTypedSyncMap creates a new TypedSyncMap
15+
// K is key type, V is value type
16+
// It is recommended to use pointer types for V because sync.Map might return nil
17+
// If sync.Map methods really returned nil, it will return the zero value of the type V
18+
func NewTypedSyncMap[K any, V any]() *TypedSyncMap[K, V] {
19+
return &TypedSyncMap[K, V]{
20+
syncMap: &sync.Map{},
21+
}
22+
}
23+
24+
// Clear deletes all the entries, resulting in an empty Map.
25+
func (m *TypedSyncMap[K, V]) Clear() {
26+
m.syncMap.Clear()
27+
}
28+
29+
// CompareAndDelete deletes the entry for key if its value is equal to old.
30+
// The old value must be of a comparable type.
31+
//
32+
// If there is no current value for key in the map, CompareAndDelete
33+
// returns false (even if the old value is the nil interface value).
34+
func (m *TypedSyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
35+
return m.syncMap.CompareAndDelete(key, old)
36+
}
37+
38+
// CompareAndSwap swaps the old and new values for key
39+
// if the value stored in the map is equal to old.
40+
// The old value must be of a comparable type.
41+
func (m *TypedSyncMap[K, V]) CompareAndSwap(key K, old V, new V) (swapped bool) {
42+
return m.syncMap.CompareAndSwap(key, old, new)
43+
}
44+
45+
// Delete deletes the value for a key.
46+
func (m *TypedSyncMap[K, V]) Delete(key K) {
47+
m.syncMap.Delete(key)
48+
}
49+
50+
// Load returns the value stored in the map for a key, or nil if no
51+
// value is present.
52+
// The ok result indicates whether value was found in the map.
53+
func (m *TypedSyncMap[K, V]) Load(key K) (value V, ok bool) {
54+
anyValue, ok := m.syncMap.Load(key)
55+
// anyValue might be nil
56+
if anyValue != nil {
57+
value = anyValue.(V)
58+
}
59+
return value, ok
60+
}
61+
62+
// LoadAndDelete deletes the value for a key, returning the previous value if any.
63+
// The loaded result reports whether the key was present.
64+
func (m *TypedSyncMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
65+
anyValue, loaded := m.syncMap.LoadAndDelete(key)
66+
if anyValue != nil {
67+
value = anyValue.(V)
68+
}
69+
return value, loaded
70+
}
71+
72+
// LoadOrStore returns the existing value for the key if present.
73+
// Otherwise, it stores and returns the given value.
74+
// The loaded result is true if the value was loaded, false if stored.
75+
func (m *TypedSyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
76+
anyActual, loaded := m.syncMap.LoadOrStore(key, value)
77+
if anyActual != nil {
78+
actual = anyActual.(V)
79+
}
80+
return actual, loaded
81+
}
82+
83+
// Range calls f sequentially for each key and value present in the map.
84+
// If f returns false, range stops the iteration.
85+
//
86+
// Range does not necessarily correspond to any consistent snapshot of the Map's
87+
// contents: no key will be visited more than once, but if the value for any key
88+
// is stored or deleted concurrently (including by f), Range may reflect any
89+
// mapping for that key from any point during the Range call. Range does not
90+
// block other methods on the receiver; even f itself may call any method on m.
91+
//
92+
// Range may be O(N) with the number of elements in the map even if f returns
93+
// false after a constant number of calls.
94+
func (m *TypedSyncMap[K, V]) Range(f func(key K, value V) bool) {
95+
m.syncMap.Range(func(key, value any) bool {
96+
return f(key.(K), value.(V))
97+
})
98+
}
99+
100+
// Store sets the value for a key.
101+
func (m *TypedSyncMap[K, V]) Store(key K, value V) {
102+
m.syncMap.Store(key, value)
103+
}
104+
105+
// Swap swaps the value for a key and returns the previous value if any. The loaded result reports whether the key was present.
106+
func (m *TypedSyncMap[K, V]) Swap(key K, value V) (previous V, loaded bool) {
107+
anyPrevious, loaded := m.syncMap.Swap(key, value)
108+
if anyPrevious != nil {
109+
previous = anyPrevious.(V)
110+
}
111+
return previous, loaded
112+
}

proxy/freedom/freedom.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/xtls/xray-core/common/session"
1919
"github.com/xtls/xray-core/common/signal"
2020
"github.com/xtls/xray-core/common/task"
21+
"github.com/xtls/xray-core/common/utils"
2122
"github.com/xtls/xray-core/core"
2223
"github.com/xtls/xray-core/features/dns"
2324
"github.com/xtls/xray-core/features/policy"
@@ -202,7 +203,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
202203
writer = buf.NewWriter(conn)
203204
}
204205
} else {
205-
writer = NewPacketWriter(conn, h, ctx, UDPOverride)
206+
writer = NewPacketWriter(conn, h, ctx, UDPOverride, destination)
206207
if h.config.Noises != nil {
207208
errors.LogDebug(ctx, "NOISE", h.config.Noises)
208209
writer = &NoisePacketWriter{
@@ -317,7 +318,8 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
317318
return buf.MultiBuffer{b}, nil
318319
}
319320

320-
func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
321+
// DialDest means the dial target used in the dialer when creating conn
322+
func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination, DialDest net.Destination) buf.Writer {
321323
iConn := conn
322324
statConn, ok := iConn.(*stat.CounterConnection)
323325
if ok {
@@ -328,12 +330,20 @@ func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride
328330
counter = statConn.WriteCounter
329331
}
330332
if c, ok := iConn.(*internet.PacketConnWrapper); ok {
333+
// If DialDest is a domain, it will be resolved in dialer
334+
// check this behavior and add it to map
335+
resolvedUDPAddr := utils.NewTypedSyncMap[string, net.Address]()
336+
if DialDest.Address.Family().IsDomain() {
337+
RemoteAddress, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
338+
resolvedUDPAddr.Store(DialDest.Address.String(), net.ParseAddress(RemoteAddress))
339+
}
331340
return &PacketWriter{
332341
PacketConnWrapper: c,
333342
Counter: counter,
334343
Handler: h,
335344
Context: ctx,
336345
UDPOverride: UDPOverride,
346+
resolvedUDPAddr: resolvedUDPAddr,
337347
}
338348

339349
}
@@ -346,6 +356,12 @@ type PacketWriter struct {
346356
*Handler
347357
context.Context
348358
UDPOverride net.Destination
359+
360+
// Dest of udp packets might be a domain, we will resolve them to IP
361+
// But resolver will return a random one if the domain has many IPs
362+
// Resulting in these packets being sent to many different IPs randomly
363+
// So, cache and keep the resolve result
364+
resolvedUDPAddr *utils.TypedSyncMap[string, net.Address]
349365
}
350366

351367
func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
@@ -364,10 +380,34 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
364380
if w.UDPOverride.Port != 0 {
365381
b.UDP.Port = w.UDPOverride.Port
366382
}
367-
if w.Handler.config.hasStrategy() && b.UDP.Address.Family().IsDomain() {
368-
ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
369-
if ip != nil {
383+
if b.UDP.Address.Family().IsDomain() {
384+
if ip, ok := w.resolvedUDPAddr.Load(b.UDP.Address.Domain()); ok {
370385
b.UDP.Address = ip
386+
} else {
387+
ShouldUseSystemResolver := true
388+
if w.Handler.config.hasStrategy() {
389+
ip = w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
390+
if ip != nil {
391+
ShouldUseSystemResolver = false
392+
}
393+
// drop packet if resolve failed when forceIP
394+
if ip == nil && w.Handler.config.forceIP() {
395+
b.Release()
396+
continue
397+
}
398+
}
399+
if ShouldUseSystemResolver {
400+
udpAddr, err := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
401+
if err != nil {
402+
b.Release()
403+
continue
404+
} else {
405+
ip = net.IPAddress(udpAddr.IP)
406+
}
407+
}
408+
if ip != nil {
409+
b.UDP.Address, _ = w.resolvedUDPAddr.LoadOrStore(b.UDP.Address.Domain(), ip)
410+
}
371411
}
372412
}
373413
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())

0 commit comments

Comments
 (0)