[client] Fix/nil relayed address #4153

Merged
merged 6 commits, Jul 15, 2025
61 changes: 36 additions & 25 deletions relay/client/client.go
@@ -181,13 +181,17 @@ func (c *Client) Connect(ctx context.Context) error {
return nil
}

if err := c.connect(ctx); err != nil {
instanceURL, err := c.connect(ctx)
if err != nil {
return err
}
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()

c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)

c.log = c.log.WithField("relay", c.instanceURL.String())
c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established")

c.serviceIsRunning = true
@@ -229,9 +233,18 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {

c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
conn := NewConn(c, peerID, msgChannel, c.instanceURL)

c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}

c.muInstanceURL.Lock()
instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)

_, ok = c.conns[peerID]
if ok {
c.mu.Unlock()
@@ -278,69 +291,67 @@ func (c *Client) Close() error {
return c.close(true)
}

func (c *Client) connect(ctx context.Context) error {
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return err
return nil, err
}
c.relayConn = conn

if err = c.handShake(ctx); err != nil {
instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
}
return err
return nil, err
}

return nil
return instanceURL, nil
}

func (c *Client) handShake(ctx context.Context) error {
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err)
return err
return nil, err
}

_, err = c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send auth message: %s", err)
return err
return nil, err
}
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(ctx, buf)
if err != nil {
c.log.Errorf("failed to read auth response: %s", err)
return err
return nil, err
}

_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return fmt.Errorf("validate version: %w", err)
return nil, fmt.Errorf("validate version: %w", err)
}

msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
return err
return nil, err
}

if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
return nil, fmt.Errorf("unexpected message type")
}

addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return err
return nil, err
}

c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
return &RelayAddr{addr: addr}, nil
}

func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
@@ -386,10 +397,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {

hc.Stop()

c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()

c.stateSubscription.Cleanup()
c.wgReadLoop.Done()
_ = c.close(false)
@@ -578,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error {
c.log.Warn("relay connection was already marked as not running")
return nil
}

c.serviceIsRunning = false

c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()

c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
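Taken together, the client.go changes above make connect and handShake return the relay's instance address to the caller instead of writing c.instanceURL as a side effect; Connect publishes it under muInstanceURL only once the handshake has succeeded, close resets it to nil, and OpenConn refuses to run while the service is not running and otherwise works from a snapshot of the address. Below is a minimal, self-contained sketch of that locking pattern; everything other than the instanceURL/muInstanceURL field names is an illustrative stand-in, not the actual netbird API.

```go
package main

import (
	"fmt"
	"sync"
)

// RelayAddr is a stand-in for the relay client's address type; only String() matters here.
type RelayAddr struct{ addr string }

func (a *RelayAddr) String() string { return a.addr }

// client sketches just the two fields this PR touches: instanceURL is read and written
// only while muInstanceURL is held, and it is nil whenever no relay connection exists.
type client struct {
	muInstanceURL sync.Mutex
	instanceURL   *RelayAddr
}

// connect stands in for Client.connect/handShake after the change: the address is
// returned to the caller instead of being assigned to c.instanceURL as a side effect.
// The URL below is a made-up placeholder.
func (c *client) connect() (*RelayAddr, error) {
	return &RelayAddr{addr: "rels://relay.example.com:443"}, nil
}

// Connect publishes the address under the mutex only after the handshake succeeded,
// mirroring the new flow in Client.Connect.
func (c *client) Connect() error {
	instanceURL, err := c.connect()
	if err != nil {
		return err
	}
	c.muInstanceURL.Lock()
	c.instanceURL = instanceURL
	c.muInstanceURL.Unlock()
	return nil
}

// currentAddr shows how OpenConn-style callers now read the field: take a snapshot
// under the lock and use the copy, so a concurrent close (which resets the field to
// nil) cannot be dereferenced halfway through an operation.
func (c *client) currentAddr() *RelayAddr {
	c.muInstanceURL.Lock()
	defer c.muInstanceURL.Unlock()
	return c.instanceURL
}

func main() {
	c := &client{}
	if err := c.Connect(); err != nil {
		panic(err)
	}
	fmt.Println("relay instance:", c.currentAddr())
}
```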
17 changes: 4 additions & 13 deletions relay/client/manager_test.go
@@ -229,16 +229,14 @@ func TestForeginAutoClose(t *testing.T) {
errChan := make(chan error, 1)
go func() {
t.Log("binding server 1.")
err := srv1.Listen(srvCfg1)
if err != nil {
if err := srv1.Listen(srvCfg1); err != nil {
errChan <- err
}
}()

defer func() {
t.Logf("closing server 1.")
err := srv1.Shutdown(ctx)
if err != nil {
if err := srv1.Shutdown(ctx); err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
@@ -287,15 +285,8 @@ }
}

t.Log("open connection to another peer")
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}

t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}

timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
59 changes: 41 additions & 18 deletions relay/client/peer_subscription.go
@@ -3,6 +3,8 @@ package client
import (
"context"
"errors"
"fmt"
"sync"
"time"

log "github.com/sirupsen/logrus"
@@ -28,6 +30,7 @@ type PeersStateSubscription struct {

listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
}

func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
@@ -43,75 +46,95 @@ func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
// OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
s.mu.Lock()
defer s.mu.Unlock()

for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID]
if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue
}

close(waitCh)
waitCh <- struct{}{}
delete(s.waitingPeers, peerID)
close(waitCh)
}
}

func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID)
}
}
s.mu.Unlock()

if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers)
}
}

// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
// todo: when we unsubscribe while this is running, this will not return with error
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
// Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online")
}

// Create a channel to wait for the peer to come online
waitCh := make(chan struct{})
waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()

if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil {
if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err)
close(waitCh)
delete(s.waitingPeers, peerID)
return err
}

defer func() {
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
}()
s.mu.Unlock()
return err
}

// Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel()
select {
case <-waitCh:
case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}

s.log.Debugf("peer %s is now online", peerID)
return nil
case <-timeoutCtx.Done():
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
}
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err()
}
}

func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
msgErr := s.unsubscribeStateChange(peerIDs)

s.mu.Lock()
for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok {
close(wch)
@@ -120,11 +143,15 @@ func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {

delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()

return msgErr
}

func (s *PeersStateSubscription) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()

for _, waitCh := range s.waitingPeers {
close(waitCh)
}
@@ -133,16 +160,12 @@ func (s *PeersStateSubscription) Cleanup() {
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
}

func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg(peerIDs)
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil {
return err
}

for _, peer := range peerIDs {
s.listenForOfflinePeers[peer] = struct{}{}
}

for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
return err
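As for peer_subscription.go: access to waitingPeers and listenForOfflinePeers is now serialized by the new mu, each wait channel is buffered with capacity 1, OnPeersOnline sends a value before deleting the entry and closing the channel, and UnsubscribeStateChange/Cleanup close without sending, which is how the waiter's `case _, ok := <-waitCh` tells a peer coming online apart from a cancelled wait. A self-contained sketch of that signaling pattern, under simplified, made-up names rather than the real types:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// waiter is a stripped-down stand-in for PeersStateSubscription: the new mutex guards
// the map, and each wait channel is buffered (size 1) so the notifier can send a value
// without blocking and then close the channel.
type waiter struct {
	mu           sync.Mutex
	waitingPeers map[string]chan struct{}
}

func newWaiter() *waiter {
	return &waiter{waitingPeers: make(map[string]chan struct{})}
}

// register mirrors the start of WaitToBeOnlineAndSubscribe: the channel is stored
// under the lock before anything can signal it.
func (w *waiter) register(peerID string) chan struct{} {
	ch := make(chan struct{}, 1)
	w.mu.Lock()
	w.waitingPeers[peerID] = ch
	w.mu.Unlock()
	return ch
}

// onPeerOnline mirrors OnPeersOnline: send first, then remove, then close.
// The buffered send cannot block even if the waiter has not reached its select yet.
func (w *waiter) onPeerOnline(peerID string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	ch, ok := w.waitingPeers[peerID]
	if !ok {
		return // unsubscribed in the meantime, nothing to signal
	}
	ch <- struct{}{}
	delete(w.waitingPeers, peerID)
	close(ch)
}

// cancel mirrors UnsubscribeStateChange/Cleanup: closing without sending is how the
// waiter learns the wait was cancelled rather than satisfied.
func (w *waiter) cancel(peerID string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if ch, ok := w.waitingPeers[peerID]; ok {
		close(ch)
		delete(w.waitingPeers, peerID)
	}
}

// wait mirrors the select in WaitToBeOnlineAndSubscribe: a received value means the
// peer is online, a closed channel means cancellation, and the timer bounds the wait.
func wait(ch chan struct{}, timeout time.Duration) error {
	select {
	case _, ok := <-ch:
		if !ok {
			return errors.New("wait for peer to come online has been cancelled")
		}
		return nil
	case <-time.After(timeout):
		return errors.New("timed out waiting for peer")
	}
}

func main() {
	w := newWaiter()
	ch := w.register("peer-1")
	go w.onPeerOnline("peer-1") // the notification may land before the waiter selects; the buffer absorbs it
	fmt.Println("wait result:", wait(ch, time.Second))
}
```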