1- from datetime import date
21from typing import (
2+ Any ,
3+ Dict ,
34 Final ,
45 List ,
56 Mapping ,
910 cast ,
1011)
1112
12- from pandas import DataFrame
13+ from pandas import CategoricalDtype , DataFrame , Series
1314from requests import Response , Session
1415from requests .auth import HTTPBasicAuth
1516from tenacity import retry , stop_after_attempt
2122from ._model import (
2223 AEpiDataCall ,
2324 EpidataFieldInfo ,
24- EpiDataFormatType ,
25+ EpidataFieldType ,
2526 EpiDataResponse ,
2627 EpiRange ,
2728 EpiRangeParam ,
2829 OnlySupportsClassicFormatException ,
2930 add_endpoint_to_url ,
3031)
32+ from ._parse import fields_to_predicate
3133
3234# Make the linter happy about the unused variables
3335__all__ = ["Epidata" , "EpiDataCall" , "EpiDataContext" , "EpiRange" , "CovidcastEpidata" ]
@@ -83,23 +85,25 @@ def with_session(self, session: Session) -> "EpiDataCall":
8385
def _call(
    self,
    fields: Optional[Sequence[str]] = None,
    stream: bool = False,
) -> Response:
    """Send the request for this call and return the raw HTTP response.

    The URL and query parameters come from ``self.request_arguments(fields)``;
    the request itself is delegated to ``_request_with_retry`` using this
    call's session.

    Args:
        fields: Optional field names, forwarded to ``request_arguments``.
        stream: Forwarded unchanged to ``_request_with_retry``.

    Returns:
        The ``Response`` produced by ``_request_with_retry``.
    """
    # The response format is no longer selectable here; only ``fields`` is
    # passed when building the request arguments.
    url, params = self.request_arguments(fields)
    return _request_with_retry(url, params, self._session, stream)
9293
9394 def classic (
9495 self ,
9596 fields : Optional [Sequence [str ]] = None ,
9697 disable_date_parsing : Optional [bool ] = False ,
98+ disable_type_parsing : Optional [bool ] = False ,
9799 ) -> EpiDataResponse :
98100 """Request and parse epidata in CLASSIC message format."""
99101 self ._verify_parameters ()
100102 try :
101- response = self ._call (None , fields )
103+ response = self ._call (fields )
102104 r = cast (EpiDataResponse , response .json ())
105+ if disable_type_parsing :
106+ return r
103107 epidata = r .get ("epidata" )
104108 if epidata and isinstance (epidata , list ) and len (epidata ) > 0 and isinstance (epidata [0 ], dict ):
105109 r ["epidata" ] = [self ._parse_row (row , disable_date_parsing = disable_date_parsing ) for row in epidata ]
def __call__(
    self,
    fields: Optional[Sequence[str]] = None,
    disable_date_parsing: Optional[bool] = False,
) -> Union[EpiDataResponse, DataFrame]:
    """Request epidata, choosing the output form from the endpoint's capabilities.

    Args:
        fields: Optional field names to request.
        disable_date_parsing: Forwarded to ``classic`` or ``df``.

    Returns:
        The result of ``self.classic(...)`` when ``self.only_supports_classic``
        is truthy; otherwise the ``DataFrame`` returned by ``self.df(...)``.
    """
    if self.only_supports_classic:
        # ``df`` raises for classic-only endpoints, so fall back to the
        # classic response (with type parsing left enabled) instead.
        return self.classic(
            fields,
            disable_date_parsing=disable_date_parsing,
            disable_type_parsing=False,
        )
    return self.df(fields, disable_date_parsing=disable_date_parsing)
133123
def df(
    self,
    # NOTE(review): the parameter annotations/defaults and the original
    # docstring are hidden by the diff hunk gap; they are reconstructed here
    # from the sibling methods (``classic``/``__call__``) -- confirm.
    fields: Optional[Sequence[str]] = None,
    disable_date_parsing: Optional[bool] = False,
) -> DataFrame:
    """Request epidata and return it as a pandas ``DataFrame`` with typed columns.

    The raw classic payload is fetched with type parsing disabled, loaded
    into a frame whose columns are the ``self.meta`` field names accepted by
    the predicate built from ``fields``, and then cast column by column
    according to each field's declared ``EpidataFieldType``.

    Args:
        fields: Optional field names; used both for the request and to
            select which ``self.meta`` entries become columns.
        disable_date_parsing: When truthy, date/epiweek fields are cast to
            ``"Int64"`` instead of ``"datetime64[ns]"``.

    Returns:
        The resulting ``DataFrame``.

    Raises:
        OnlySupportsClassicFormatException: If ``self.only_supports_classic``
            is truthy.
    """
    if self.only_supports_classic:
        raise OnlySupportsClassicFormatException()
    self._verify_parameters()

    # Row-level parsing is skipped on purpose: dtypes are applied per column below.
    payload = self.classic(fields, disable_type_parsing=True)
    rows = payload.get("epidata", [])
    pred = fields_to_predicate(fields)
    columns: List[str] = [info.name for info in self.meta if pred(info.name)]
    # An empty column list becomes None so pandas infers columns from the rows.
    frame = DataFrame(rows, columns=columns or None)

    date_like = (
        EpidataFieldType.date,
        EpidataFieldType.epiweek,
        EpidataFieldType.date_or_epiweek,
    )
    data_types: Dict[str, Any] = {}
    for info in self.meta:
        # Skip fields that were not requested, and columns with no data at
        # all (those are left with whatever dtype pandas gave them).
        if not pred(info.name) or frame[info.name].isnull().all():
            continue
        if info.type == EpidataFieldType.bool:
            data_types[info.name] = bool
        elif info.type == EpidataFieldType.categorical:
            data_types[info.name] = CategoricalDtype(
                categories=Series(info.categories) if info.categories else None,
                ordered=True,
            )
        elif info.type == EpidataFieldType.int:
            # Must be exactly "Int64" (no stray whitespace) to name the
            # pandas nullable integer dtype.
            data_types[info.name] = "Int64"
        elif info.type in date_like:
            # NOTE(review): astype("datetime64[ns]") interprets plain integers
            # as nanosecond offsets -- confirm the raw values are date-like.
            data_types[info.name] = "Int64" if disable_date_parsing else "datetime64[ns]"
        elif info.type == EpidataFieldType.float:
            data_types[info.name] = "Float64"
        else:
            data_types[info.name] = "string"
    if data_types:
        frame = frame.astype(data_types)
    return frame
174164
175165
176166class EpiDataContext (AEpiDataEndpoints [EpiDataCall ]):
0 commit comments