1212from kafka .protocol import CODEC_NONE
1313
1414import threading
15- import multiprocessing as mp
1615try :
17- from queue import Empty
16+ from queue import Empty , Queue
1817except ImportError :
19- from Queue import Empty
18+ from Queue import Empty , Queue
2019
2120
2221class TestKafkaProducer (unittest .TestCase ):
@@ -56,33 +55,26 @@ def partitions(topic):
5655class TestKafkaProducerSendUpstream (unittest .TestCase ):
5756
5857 def setUp (self ):
59-
60- # create a multiprocessing Value to store call counter
61- # (magicmock counters don't work with other processes)
62- self .send_calls_count = mp .Value ('i' , 0 )
63-
64- def send_side_effect (* args , ** kwargs ):
65- self .send_calls_count .value += 1
66-
6758 self .client = MagicMock ()
68- self .client .send_produce_request .side_effect = send_side_effect
69- self .queue = mp .Queue ()
59+ self .queue = Queue ()
7060
7161 def _run_process (self , retries_limit = 3 , sleep_timeout = 1 ):
7262 # run _send_upstream process with the queue
73- self .process = mp .Process (
63+ stop_event = threading .Event ()
64+ self .thread = threading .Thread (
7465 target = _send_upstream ,
7566 args = (self .queue , self .client , CODEC_NONE ,
7667 0.3 , # batch time (seconds)
7768 3 , # batch length
7869 Producer .ACK_AFTER_LOCAL_WRITE ,
7970 Producer .DEFAULT_ACK_TIMEOUT ,
8071 50 , # retry backoff (ms)
81- retries_limit ))
82- self .process .daemon = True
83- self .process .start ()
72+ retries_limit ,
73+ stop_event ))
74+ self .thread .daemon = True
75+ self .thread .start ()
8476 time .sleep (sleep_timeout )
85- self . process . terminate ()
77+ stop_event . set ()
8678
8779 def test_wo_retries (self ):
8880
@@ -97,7 +89,8 @@ def test_wo_retries(self):
9789
9890 # there should be 4 non-void cals:
9991 # 3 batches of 3 msgs each + 1 batch of 1 message
100- self .assertEqual (self .send_calls_count .value , 4 )
92+ self .assertEqual (self .client .send_produce_request .call_count , 4 )
93+
10194
10295 def test_first_send_failed (self ):
10396
@@ -106,11 +99,10 @@ def test_first_send_failed(self):
10699 for i in range (10 ):
107100 self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
108101
109- is_first_time = mp . Value ( 'b' , True )
102+ self . client . is_first_time = True
110103 def send_side_effect (reqs , * args , ** kwargs ):
111- self .send_calls_count .value += 1
112- if is_first_time .value :
113- is_first_time .value = False
104+ if self .client .is_first_time :
105+ self .client .is_first_time = False
114106 raise FailedPayloadsError (reqs )
115107
116108 self .client .send_produce_request .side_effect = send_side_effect
@@ -122,7 +114,7 @@ def send_side_effect(reqs, *args, **kwargs):
122114
123115 # there should be 5 non-void cals: 1st failed batch of 3 msgs
124116 # + 3 batches of 3 msgs each + 1 batch of 1 msg = 1 + 3 + 1 = 5
125- self .assertEqual (self .send_calls_count . value , 5 )
117+ self .assertEqual (self .client . send_produce_request . call_count , 5 )
126118
127119 def test_with_limited_retries (self ):
128120
@@ -132,7 +124,6 @@ def test_with_limited_retries(self):
132124 self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
133125
134126 def send_side_effect (reqs , * args , ** kwargs ):
135- self .send_calls_count .value += 1
136127 raise FailedPayloadsError (reqs )
137128
138129 self .client .send_produce_request .side_effect = send_side_effect
@@ -145,8 +136,7 @@ def send_side_effect(reqs, *args, **kwargs):
145136 # there should be 16 non-void cals:
146137 # 3 initial batches of 3 msgs each + 1 initial batch of 1 msg +
147138 # 3 retries of the batches above = 4 + 3 * 4 = 16, all failed
148- self .assertEqual (self .send_calls_count .value , 16 )
149-
139+ self .assertEqual (self .client .send_produce_request .call_count , 16 )
150140
151141 def test_with_unlimited_retries (self ):
152142
@@ -156,7 +146,6 @@ def test_with_unlimited_retries(self):
156146 self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
157147
158148 def send_side_effect (reqs , * args , ** kwargs ):
159- self .send_calls_count .value += 1
160149 raise FailedPayloadsError (reqs )
161150
162151 self .client .send_produce_request .side_effect = send_side_effect
@@ -174,5 +163,5 @@ def send_side_effect(reqs, *args, **kwargs):
174163 self .assertEqual (self .queue .empty (), True )
175164
176165 # 1s / 50ms of backoff = 20 times max
177- self .assertTrue ( self . send_calls_count . value > 10 )
178- self .assertTrue (self . send_calls_count . value <= 20 )
166+ calls = self .client . send_produce_request . call_count
167+ self .assertTrue (calls > 10 & calls <= 20 )
0 commit comments