2626
2727import math
2828import random
29- from typing import Callable
29+ from typing import Callable , Optional
3030
3131import numpy as np
3232from scipy .stats import bootstrap
@@ -45,18 +45,27 @@ def mean_stderr(arr):
4545
4646
4747class _bootstrap_internal :
48- def __init__ (self , metric : Callable , number_draws : int ):
49- self .metric = metric
48+ def __init__ (self , number_draws : int , metric : Optional [Callable ] = None ):
5049 self .number_draws = number_draws
50+ self .metric = metric
5151
5252 def __call__ (self , cur_experiment ):
5353 # Creates number_draws samplings (with replacement) of the population by iterating on a given seed
5454 population , seed = cur_experiment
5555 rnd = random .Random ()
5656 rnd .seed (seed )
5757 samplings = []
58- for _ in range (self .number_draws ):
59- samplings .append (self .metric (rnd .choices (population , k = len (population ))))
58+ import multiprocessing as mp
59+
60+ with mp .Pool (mp .cpu_count ()) as pool :
61+ samplings = pool .starmap (
62+ self .metric ,
63+ tqdm (
64+ [(rnd .choices (population , k = len (population )),) for _ in range (self .number_draws )],
65+ total = self .number_draws ,
66+ desc = "Sampling bootstrap iterations" ,
67+ ),
68+ )
6069 return samplings
6170
6271
@@ -65,28 +74,15 @@ def bootstrap_stderr(metric: Callable, population: list, number_experiments: int
6574 by sampling said population for number_experiments and recomputing the metric on the
6675 different samplings.
6776 """
68- import multiprocessing as mp
69-
70- pool = mp .Pool (mp .cpu_count ())
71-
7277 res = []
7378 number_draws = min (1000 , number_experiments )
74- # We change the seed every 1000 re-samplings
75- # and do the experiment 1000 re-samplings at a time
7679 number_seeds = number_experiments // number_draws
7780
78- hlog (f"Bootstrapping { metric .__name__ } 's stderr." )
79- for cur_bootstrap in tqdm (
80- pool .imap (
81- _bootstrap_internal (metric = metric , number_draws = number_draws ),
82- ((population , seed ) for seed in range (number_seeds )),
83- ),
84- total = number_seeds ,
85- ):
81+ hlog (f"Bootstrapping { metric .__name__ } 's stderr with { number_seeds } seeds." )
82+ for seed in range (number_seeds ):
8683 # sample w replacement
87- res .extend (cur_bootstrap )
84+ res .extend (_bootstrap_internal ( metric = metric , number_draws = number_draws )(( population , seed )) )
8885
89- pool .close ()
9086 return mean_stderr (res )
9187
9288
0 commit comments