11package network
22
33import (
4+ "net"
5+
46 "github.com/sagernet/sing/common"
57 E "github.com/sagernet/sing/common/exceptions"
68)
@@ -13,17 +15,65 @@ type HandshakeSuccess interface {
1315 HandshakeSuccess () error
1416}
1517
16- func ReportHandshakeFailure (conn any , err error ) error {
17- if handshakeConn , isHandshakeConn := common.Cast [HandshakeFailure ](conn ); isHandshakeConn {
18+ type ConnHandshakeSuccess interface {
19+ ConnHandshakeSuccess (conn net.Conn ) error
20+ }
21+
22+ type PacketConnHandshakeSuccess interface {
23+ PacketConnHandshakeSuccess (conn net.PacketConn ) error
24+ }
25+
26+ func ReportHandshakeFailure (reporter any , err error ) error {
27+ if handshakeConn , isHandshakeConn := common.Cast [HandshakeFailure ](reporter ); isHandshakeConn {
1828 return E .Append (err , handshakeConn .HandshakeFailure (err ), func (err error ) error {
1929 return E .Cause (err , "write handshake failure" )
2030 })
2131 }
2232 return err
2333}
2434
25- func ReportHandshakeSuccess (conn any ) error {
26- if handshakeConn , isHandshakeConn := common.Cast [HandshakeSuccess ](conn ); isHandshakeConn {
35+ func CloseOnHandshakeFailure (reporter any , onClose CloseHandler , err error ) error {
36+ if handshakeConn , isHandshakeConn := common.Cast [HandshakeFailure ](reporter ); isHandshakeConn {
37+ err = E .Append (err , handshakeConn .HandshakeFailure (err ), func (err error ) error {
38+ return E .Cause (err , "write handshake failure" )
39+ })
40+ } else {
41+ if tcpConn , isTCPConn := common.Cast [interface {
42+ SetLinger (sec int ) error
43+ }](reporter ); isTCPConn {
44+ tcpConn .SetLinger (0 )
45+ }
46+ common .Close (reporter )
47+ }
48+ if onClose != nil {
49+ onClose (err )
50+ }
51+ return err
52+ }
53+
54+ // Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
55+ func ReportHandshakeSuccess (reporter any ) error {
56+ if handshakeConn , isHandshakeConn := common.Cast [HandshakeSuccess ](reporter ); isHandshakeConn {
57+ return handshakeConn .HandshakeSuccess ()
58+ }
59+ return nil
60+ }
61+
62+ func ReportConnHandshakeSuccess (reporter any , conn net.Conn ) error {
63+ if handshakeConn , isHandshakeConn := common.Cast [ConnHandshakeSuccess ](reporter ); isHandshakeConn {
64+ return handshakeConn .ConnHandshakeSuccess (conn )
65+ }
66+ if handshakeConn , isHandshakeConn := common.Cast [HandshakeSuccess ](reporter ); isHandshakeConn {
67+ return handshakeConn .HandshakeSuccess ()
68+ }
69+ return nil
70+ }
71+
72+ func ReportPacketConnHandshakeSuccess (reporter any , conn net.PacketConn ) error {
73+ if handshakeConn , isHandshakeConn := common.Cast [PacketConnHandshakeSuccess ](reporter ); isHandshakeConn {
74+ return handshakeConn .PacketConnHandshakeSuccess (conn )
75+ }
76+ if handshakeConn , isHandshakeConn := common.Cast [HandshakeSuccess ](reporter ); isHandshakeConn {
2777 return handshakeConn .HandshakeSuccess ()
2878 }
2979 return nil
0 commit comments