diff --git a/floodsub.go b/floodsub.go index 45b3fdee..4c943bba 100644 --- a/floodsub.go +++ b/floodsub.go @@ -71,7 +71,7 @@ func (fs *FloodSubRouter) AcceptFrom(peer.ID) AcceptStatus { return AcceptAll } -func (fs *FloodSubRouter) PreValidation([]*Message) {} +func (fs *FloodSubRouter) PreValidation(from peer.ID, msgs []*Message) {} func (fs *FloodSubRouter) HandleRPC(rpc *RPC) {} diff --git a/gossipsub.go b/gossipsub.go index 214696b7..ecd4edaa 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -707,7 +707,7 @@ func (gs *GossipSubRouter) AcceptFrom(p peer.ID) AcceptStatus { // PreValidation sends the IDONTWANT control messages to all the mesh // peers. They need to be sent right before the validation because they // should be seen by the peers as soon as possible. -func (gs *GossipSubRouter) PreValidation(msgs []*Message) { +func (gs *GossipSubRouter) PreValidation(from peer.ID, msgs []*Message) { tmids := make(map[string][]string) for _, msg := range msgs { if len(msg.GetData()) < gs.params.IDontWantMessageThreshold { @@ -724,6 +724,10 @@ func (gs *GossipSubRouter) PreValidation(msgs []*Message) { shuffleStrings(mids) // send IDONTWANT to all the mesh peers for p := range gs.mesh[topic] { + if p == from { + // We don't send IDONTWANT to the peer that sent us the messages + continue + } // send to only peers that support IDONTWANT if gs.feature(GossipSubFeatureIdontwant, gs.peers[p]) { idontwant := []*pb.ControlIDontWant{{MessageIDs: mids}} diff --git a/gossipsub_test.go b/gossipsub_test.go index 675d164c..abb347fd 100644 --- a/gossipsub_test.go +++ b/gossipsub_test.go @@ -2815,6 +2815,78 @@ func TestGossipsubIdontwantReceive(t *testing.T) { <-ctx.Done() } +type mockRawTracer struct { + onRecvRPC func(*RPC) +} + +func (m *mockRawTracer) RecvRPC(rpc *RPC) { + if m.onRecvRPC != nil { + m.onRecvRPC(rpc) + } +} + +func (m *mockRawTracer) AddPeer(p peer.ID, proto protocol.ID) {} +func (m *mockRawTracer) DeliverMessage(msg *Message) {} +func (m *mockRawTracer) DropRPC(rpc *RPC, p peer.ID) {} +func (m *mockRawTracer) DuplicateMessage(msg *Message) {} +func (m *mockRawTracer) Graft(p peer.ID, topic string) {} +func (m *mockRawTracer) Join(topic string) {} +func (m *mockRawTracer) Leave(topic string) {} +func (m *mockRawTracer) Prune(p peer.ID, topic string) {} +func (m *mockRawTracer) RejectMessage(msg *Message, reason string) {} +func (m *mockRawTracer) RemovePeer(p peer.ID) {} +func (m *mockRawTracer) SendRPC(rpc *RPC, p peer.ID) {} +func (m *mockRawTracer) ThrottlePeer(p peer.ID) {} +func (m *mockRawTracer) UndeliverableMessage(msg *Message) {} +func (m *mockRawTracer) ValidateMessage(msg *Message) {} + +var _ RawTracer = &mockRawTracer{} + +func TestGossipsubNoIDONTWANTToMessageSender(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hosts := getDefaultHosts(t, 3) + denseConnect(t, hosts) + + psubs := make([]*PubSub, 2) + + receivedIDONTWANT := make(chan struct{}) + psubs[0] = getGossipsub(ctx, hosts[0], WithRawTracer(&mockRawTracer{ + onRecvRPC: func(rpc *RPC) { + if len(rpc.GetControl().GetIdontwant()) > 0 { + close(receivedIDONTWANT) + } + }, + })) + psubs[1] = getGossipsub(ctx, hosts[1]) + + topicString := "foobar" + var topics []*Topic + for _, ps := range psubs { + topic, err := ps.Join(topicString) + if err != nil { + t.Fatal(err) + } + topics = append(topics, topic) + + _, err = ps.Subscribe(topicString) + if err != nil { + t.Fatal(err) + } + } + time.Sleep(time.Second) + + msg := make([]byte, GossipSubIDontWantMessageThreshold+1) + topics[0].Publish(ctx, msg) + + select { + case <-receivedIDONTWANT: + t.Fatal("IDONTWANT should not be sent to the message sender") + case <-time.After(time.Second): + } + +} + // Test that non-mesh peers will not get IDONTWANT func TestGossipsubIdontwantNonMesh(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/pubsub.go b/pubsub.go index 3ca14abb..5c27c3e9 100644 --- a/pubsub.go +++ b/pubsub.go @@ -203,7 +203,7 @@ type PubSubRouter interface { AcceptFrom(peer.ID) AcceptStatus // PreValidation is invoked on messages in the RPC envelope right before pushing it to // the validation pipeline - PreValidation([]*Message) + PreValidation(from peer.ID, msgs []*Message) // HandleRPC is invoked to process control messages in the RPC envelope. // It is invoked after subscriptions and payload messages have been processed. HandleRPC(*RPC) @@ -1106,7 +1106,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { toPush = append(toPush, msg) } } - p.rt.PreValidation(toPush) + p.rt.PreValidation(rpc.from, toPush) for _, msg := range toPush { p.pushMsg(msg) } diff --git a/randomsub.go b/randomsub.go index 4e410f5f..f9f64736 100644 --- a/randomsub.go +++ b/randomsub.go @@ -94,7 +94,7 @@ func (rs *RandomSubRouter) AcceptFrom(peer.ID) AcceptStatus { return AcceptAll } -func (rs *RandomSubRouter) PreValidation([]*Message) {} +func (rs *RandomSubRouter) PreValidation(from peer.ID, msgs []*Message) {} func (rs *RandomSubRouter) HandleRPC(rpc *RPC) {}