diff --git a/src/hugchat/login.py b/src/hugchat/login.py index 6573bc7..e0746dd 100644 --- a/src/hugchat/login.py +++ b/src/hugchat/login.py @@ -103,11 +103,11 @@ def loadCookiesFromDir(self, cookie_dir_path: str = './usercookies') -> requests raise Exception( "load cookies from files fatal. Please check the format") - def _request_get(self, url: str, params=None, allow_redirects=True) -> requests.Response: + def _request_get(self, url: str, params=None, allow_redirects=True, headers=None) -> requests.Response: res = requests.get( url, params=params, - headers=self.headers, + headers=self.headers if headers is None else headers, cookies=self.cookies, allow_redirects=allow_redirects, ) @@ -156,7 +156,7 @@ def _get_auth_url(self): "Content-Type": "application/x-www-form-urlencoded", "Origin": "https://huggingface.co/chat" } - res = self._request_post(url, headers=headers, allow_redirects=False) + res = self._request_get(url, headers=headers, allow_redirects=False) if res.status_code == 200: # location = res.headers.get("Location", None) location = res.json()["location"] @@ -165,7 +165,7 @@ def _get_auth_url(self): else: raise Exception( "No authorize url found, please check your email or password.") - elif res.status_code == 303: + elif res.status_code == 303 or res.status_code == 302: location = res.headers.get("Location") if location: return location