@@ -479,11 +479,23 @@ def sample(
479479 model .check_start_vals (ip )
480480 _check_start_shape (model , ip )
481481
482+ # Create trace backends for each chain
483+ traces = [
484+ _init_trace (
485+ expected_length = draws + tune ,
486+ stats_dtypes = step .stats_dtypes ,
487+ chain_number = chain_number ,
488+ trace = trace ,
489+ model = model ,
490+ )
491+ for chain_number in range (chains )
492+ ]
493+
482494 sample_args = {
483495 "draws" : draws ,
484496 "step" : step ,
485497 "start" : initial_points ,
486- "trace " : trace ,
498+ "traces " : traces ,
487499 "chains" : chains ,
488500 "tune" : tune ,
489501 "progressbar" : progressbar ,
@@ -524,12 +536,7 @@ def sample(
524536 _log .info (f"Multiprocess sampling ({ chains } chains in { cores } jobs)" )
525537 _print_step_hierarchy (step )
526538 try :
527- traces = _mp_sample (** sample_args , ** parallel_args )
528- if discard_tuned_samples :
529- traces , length = _choose_chains (traces , tune )
530- else :
531- traces , length = _choose_chains (traces , 0 )
532- mtrace = MultiTrace (traces )[:length ]
539+ _mp_sample (** sample_args , ** parallel_args )
533540 except pickle .PickleError :
534541 _log .warning ("Could not pickle model, sampling singlethreaded." )
535542 _log .debug ("Pickling error:" , exc_info = True )
@@ -544,15 +551,21 @@ def sample(
544551 if has_population_samplers :
545552 _log .info (f"Population sampling ({ chains } chains)" )
546553 _print_step_hierarchy (step )
547- mtrace = _sample_population (
548- initial_points = initial_points , parallelize = cores > 1 , ** sample_args
549- )
554+ _sample_population (initial_points = initial_points , parallelize = cores > 1 , ** sample_args )
550555 else :
551556 _log .info (f"Sequential sampling ({ chains } chains in 1 job)" )
552557 _print_step_hierarchy (step )
553- mtrace = _sample_many (** sample_args )
558+ _sample_many (** sample_args )
554559
555560 t_sampling = time .time () - t_start
561+
562+ # Wrap chain traces in a MultiTrace
563+ if discard_tuned_samples :
564+ traces , length = _choose_chains (traces , tune )
565+ else :
566+ traces , length = _choose_chains (traces , 0 )
567+ mtrace = MultiTrace (traces )[:length ]
568+
556569 # count the number of tune/draw iterations that happened
557570 # ideally via the "tune" statistic, but not all samplers record it!
558571 if "tune" in mtrace .stat_names :
@@ -639,12 +652,13 @@ def _sample_many(
639652 * ,
640653 draws : int ,
641654 chains : int ,
655+ traces : Sequence [BaseTrace ],
642656 start : Sequence [PointType ],
643657 random_seed : Optional [Sequence [RandomSeed ]],
644658 step ,
645659 callback = None ,
646660 ** kwargs ,
647- ) -> MultiTrace :
661+ ):
648662 """Samples all chains sequentially.
649663
650664 Parameters
@@ -659,35 +673,19 @@ def _sample_many(
659673 A list of seeds, one for each chain
660674 step: function
661675 Step function
662-
663- Returns
664- -------
665- mtrace: MultiTrace
666- Contains samples of all chains
667676 """
668- traces : List [BaseTrace ] = []
669677 for i in range (chains ):
670- trace = _sample (
678+ _sample (
671679 draws = draws ,
672680 chain = i ,
673681 start = start [i ],
674682 step = step ,
683+ trace = traces [i ],
675684 random_seed = None if random_seed is None else random_seed [i ],
676685 callback = callback ,
677686 ** kwargs ,
678687 )
679- if trace is None :
680- if len (traces ) == 0 :
681- raise ValueError ("Sampling stopped before a sample was created." )
682- else :
683- break
684- elif len (trace ) < draws :
685- if len (traces ) == 0 :
686- traces .append (trace )
687- break
688- else :
689- traces .append (trace )
690- return MultiTrace (traces )
688+ return
691689
692690
693691def _sample (
@@ -698,12 +696,12 @@ def _sample(
698696 start : PointType ,
699697 draws : int ,
700698 step = None ,
701- trace : Optional [ BaseTrace ] = None ,
699+ trace : BaseTrace ,
702700 tune : int ,
703701 model : Optional [Model ] = None ,
704702 callback = None ,
705703 ** kwargs ,
706- ) -> BaseTrace :
704+ ) -> None :
707705 """Main iteration for singleprocess sampling.
708706
709707 Multiple step methods are supported via compound step methods.
@@ -724,16 +722,10 @@ def _sample(
724722 step : function
725723 Step function
726724 trace : backend, optional
727- A backend instance or None.
728- If None, the NDArray backend is used.
725+ A backend instance.
729726 tune : int
730727 Number of iterations to tune.
731728 model : Model (optional if in ``with`` context)
732-
733- Returns
734- -------
735- strace : BaseTrace
736- A ``BaseTrace`` object that contains the samples for this chain.
737729 """
738730 skip_first = kwargs .get ("skip_first" , 0 )
739731
@@ -756,31 +748,27 @@ def _sample(
756748 else :
757749 sampling = sampling_gen
758750 try :
759- strace = None
760- for it , (strace , diverging ) in enumerate (sampling ):
751+ for it , diverging in enumerate (sampling ):
761752 if it >= skip_first and diverging :
762753 _pbar_data ["divergences" ] += 1
763754 if progressbar :
764755 sampling .comment = _desc .format (** _pbar_data )
765756 except KeyboardInterrupt :
766757 pass
767- if strace is None :
768- raise Exception ("KeyboardInterrupt happened before the base trace was created." )
769- return strace
770758
771759
772760def _iter_sample (
773761 * ,
774762 draws : int ,
775763 step ,
776764 start : PointType ,
777- trace : Optional [ BaseTrace ] = None ,
765+ trace : BaseTrace ,
778766 chain : int = 0 ,
779767 tune : int = 0 ,
780768 model = None ,
781769 random_seed : RandomSeed = None ,
782770 callback = None ,
783- ) -> Iterator [Tuple [ BaseTrace , bool ] ]:
771+ ) -> Iterator [bool ]:
784772 """Generator for sampling one chain. (Used in singleprocess sampling.)
785773
786774 Parameters
@@ -792,9 +780,8 @@ def _iter_sample(
792780 start : dict
793781 Starting point in parameter space (or partial point).
794782 Must contain numeric (transformed) initial values for all (transformed) free variables.
795- trace : backend, optional
796- A backend instance or None.
797- If None, the NDArray backend is used.
783+ trace : backend
784+ A backend instance.
798785 chain : int, optional
799786 Chain number used to store sample in backend.
800787 tune : int, optional
@@ -804,8 +791,6 @@ def _iter_sample(
804791
805792 Yields
806793 ------
807- strace : BaseTrace
808- The trace object containing the samples for this chain
809794 diverging : bool
810795 Indicates if the draw is divergent. Only available with some samplers.
811796 """
@@ -825,14 +810,6 @@ def _iter_sample(
825810
826811 point = start
827812
828- strace : BaseTrace = _init_trace (
829- expected_length = draws + tune ,
830- stats_dtypes = step .stats_dtypes ,
831- chain_number = chain ,
832- trace = trace ,
833- model = model ,
834- )
835-
836813 try :
837814 step .tune = bool (tune )
838815 if hasattr (step , "reset_tuning" ):
@@ -846,24 +823,24 @@ def _iter_sample(
846823 if i == tune :
847824 step .stop_tuning ()
848825 point , stats = step .step (point )
849- strace .record (point , stats )
826+ trace .record (point , stats )
850827 log_warning_stats (stats )
851828 diverging = i > tune and stats and stats [0 ].get ("diverging" )
852829 if callback is not None :
853830 callback (
854- trace = strace ,
831+ trace = trace ,
855832 draw = Draw (chain , i == draws , i , i < tune , stats , point ),
856833 )
857834
858- yield strace , diverging
835+ yield diverging
859836 except KeyboardInterrupt :
860- strace .close ()
837+ trace .close ()
861838 raise
862839 except BaseException :
863- strace .close ()
840+ trace .close ()
864841 raise
865842 else :
866- strace .close ()
843+ trace .close ()
867844
868845
869846def _mp_sample (
@@ -876,12 +853,12 @@ def _mp_sample(
876853 random_seed : Sequence [RandomSeed ],
877854 start : Sequence [PointType ],
878855 progressbar : bool = True ,
879- trace : Optional [BaseTrace ] = None ,
856+ traces : Sequence [BaseTrace ],
880857 model = None ,
881858 callback = None ,
882859 mp_ctx = None ,
883860 ** kwargs ,
884- ) -> List [ BaseTrace ] :
861+ ) -> None :
885862 """Main iteration for multiprocess sampling.
886863
887864 Parameters
@@ -913,28 +890,12 @@ def _mp_sample(
913890 the ``draw.chain`` argument can be used to determine which of the active chains the sample
914891 is drawn from.
915892 Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
916-
917- Returns
918- -------
919- traces
920- All chains.
921893 """
922894 import pymc .sampling .parallel as ps
923895
924896 # We did draws += tune in pm.sample
925897 draws -= tune
926898
927- traces = [
928- _init_trace (
929- expected_length = draws + tune ,
930- stats_dtypes = step .stats_dtypes ,
931- chain_number = chain_number ,
932- trace = trace ,
933- model = model ,
934- )
935- for chain_number in range (chains )
936- ]
937-
938899 sampler = ps .ParallelSampler (
939900 draws = draws ,
940901 tune = tune ,
@@ -957,7 +918,7 @@ def _mp_sample(
957918 strace .close ()
958919
959920 if callback is not None :
960- callback (trace = trace , draw = draw )
921+ callback (trace = strace , draw = draw )
961922
962923 except ps .ParallelSamplingError as error :
963924 strace = traces [error ._chain ]
@@ -967,9 +928,8 @@ def _mp_sample(
967928 multitrace = MultiTrace (traces )
968929 multitrace ._report ._log_summary ()
969930 raise
970- return traces
971931 except KeyboardInterrupt :
972- return traces
932+ pass
973933 finally :
974934 for strace in traces :
975935 strace .close ()
0 commit comments