Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/async-await.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ not work with async views because they will not await the function or be
awaitable. Other functions they provide will not be awaitable either and
will probably be blocking if called within an async view.

Extension authors can support async functions by utilising the
:meth:`flask.Flask.ensure_sync` method. For example, if the extension
provides a view function decorator add ``ensure_sync`` before calling
the decorated function,

.. code-block:: python

def extension(func):
@wraps(func)
def wrapper(*args, **kwargs):
... # Extension logic
return current_app.ensure_sync(func)(*args, **kwargs)

return wrapper

Check the changelog of the extension you want to use to see if they've
implemented async support, or make a feature request or PR to them.

Expand Down
60 changes: 44 additions & 16 deletions src/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from werkzeug.exceptions import BadRequestKeyError
from werkzeug.exceptions import HTTPException
from werkzeug.exceptions import InternalServerError
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError
from werkzeug.routing import Map
from werkzeug.routing import MapAdapter
Expand All @@ -40,7 +41,6 @@
from .helpers import get_flashed_messages
from .helpers import get_load_dotenv
from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import url_for
from .json import jsonify
from .logging import create_logger
Expand Down Expand Up @@ -1080,14 +1080,12 @@ def add_url_rule(
self.url_map.add(rule)
if view_func is not None:
old_func = self.view_functions.get(endpoint)
if getattr(old_func, "_flask_sync_wrapper", False):
old_func = old_func.__wrapped__ # type: ignore
if old_func is not None and old_func != view_func:
raise AssertionError(
"View function mapping is overwriting an existing"
f" endpoint function: {endpoint}"
)
self.view_functions[endpoint] = self.ensure_sync(view_func)
self.view_functions[endpoint] = view_func

@setupmethod
def template_filter(self, name: t.Optional[str] = None) -> t.Callable:
Expand Down Expand Up @@ -1208,7 +1206,7 @@ def before_first_request(self, f: BeforeRequestCallable) -> BeforeRequestCallabl

.. versionadded:: 0.8
"""
self.before_first_request_funcs.append(self.ensure_sync(f))
self.before_first_request_funcs.append(f)
return f

@setupmethod
Expand Down Expand Up @@ -1241,7 +1239,7 @@ def teardown_appcontext(self, f: TeardownCallable) -> TeardownCallable:

.. versionadded:: 0.9
"""
self.teardown_appcontext_funcs.append(self.ensure_sync(f))
self.teardown_appcontext_funcs.append(f)
return f

@setupmethod
Expand Down Expand Up @@ -1308,7 +1306,7 @@ def handle_http_exception(
handler = self._find_error_handler(e)
if handler is None:
return e
return handler(e)
return self.ensure_sync(handler)(e)

def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default
Expand Down Expand Up @@ -1375,7 +1373,7 @@ def handle_user_exception(
if handler is None:
raise

return handler(e)
return self.ensure_sync(handler)(e)

def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
Expand Down Expand Up @@ -1422,7 +1420,7 @@ def handle_exception(self, e: Exception) -> Response:
handler = self._find_error_handler(server_error)

if handler is not None:
server_error = handler(server_error)
server_error = self.ensure_sync(handler)(server_error)

return self.finalize_request(server_error, from_error_handler=True)

Expand Down Expand Up @@ -1484,7 +1482,7 @@ def dispatch_request(self) -> ResponseReturnValue:
):
return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args)
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)

def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request
Expand Down Expand Up @@ -1545,7 +1543,7 @@ def try_trigger_before_first_request_functions(self) -> None:
if self._got_first_request:
return
for func in self.before_first_request_funcs:
func()
self.ensure_sync(func)()
self._got_first_request = True

def make_default_options_response(self) -> Response:
Expand Down Expand Up @@ -1581,10 +1579,40 @@ def ensure_sync(self, func: t.Callable) -> t.Callable:
.. versionadded:: 2.0
"""
if iscoroutinefunction(func):
return run_async(func)
return self.async_to_sync(func)

return func

def async_to_sync(
self, func: t.Callable[..., t.Coroutine]
) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function.

.. code-block:: python

result = app.async_to_sync(func)(*args, **kwargs)

Override this method to change how the app converts async code
to be synchronously callable.

.. versionadded:: 2.0
"""
try:
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)

# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python "
"and Greenlet versions."
)

return asgiref_async_to_sync(func)

def make_response(self, rv: ResponseReturnValue) -> Response:
"""Convert the return value from a view function to an instance of
:attr:`response_class`.
Expand Down Expand Up @@ -1807,7 +1835,7 @@ def preprocess_request(self) -> t.Optional[ResponseReturnValue]:
if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp])
for func in funcs:
rv = func()
rv = self.ensure_sync(func)()
if rv is not None:
return rv

Expand All @@ -1834,7 +1862,7 @@ def process_response(self, response: Response) -> Response:
if None in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
for handler in funcs:
response = handler(response)
response = self.ensure_sync(handler)(response)
if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)
return response
Expand Down Expand Up @@ -1871,7 +1899,7 @@ def do_teardown_request(
if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs:
func(exc)
self.ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc)

def do_teardown_appcontext(
Expand All @@ -1894,7 +1922,7 @@ def do_teardown_appcontext(
if exc is _sentinel:
exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs):
func(exc)
self.ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc)

def app_context(self) -> AppContext:
Expand Down
40 changes: 8 additions & 32 deletions src/flask/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,10 @@ def register(self, app: "Flask", options: dict) -> None:
# Merge blueprint data into parent.
if first_registration:

def extend(bp_dict, parent_dict, ensure_sync=False):
def extend(bp_dict, parent_dict):
for key, values in bp_dict.items():
key = self.name if key is None else f"{self.name}.{key}"

if ensure_sync:
values = [app.ensure_sync(func) for func in values]

parent_dict[key].extend(values)

for key, value in self.error_handler_spec.items():
Expand All @@ -307,25 +304,21 @@ def extend(bp_dict, parent_dict, ensure_sync=False):
dict,
{
code: {
exc_class: app.ensure_sync(func)
for exc_class, func in code_values.items()
exc_class: func for exc_class, func in code_values.items()
}
for code, code_values in value.items()
},
)
app.error_handler_spec[key] = value

for endpoint, func in self.view_functions.items():
app.view_functions[endpoint] = app.ensure_sync(func)
app.view_functions[endpoint] = func

extend(
self.before_request_funcs, app.before_request_funcs, ensure_sync=True
)
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
extend(self.before_request_funcs, app.before_request_funcs)
extend(self.after_request_funcs, app.after_request_funcs)
extend(
self.teardown_request_funcs,
app.teardown_request_funcs,
ensure_sync=True,
)
extend(self.url_default_functions, app.url_default_functions)
extend(self.url_value_preprocessors, app.url_value_preprocessors)
Expand Down Expand Up @@ -478,9 +471,7 @@ def before_app_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable:
before each request, even if outside of a blueprint.
"""
self.record_once(
lambda s: s.app.before_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
)
return f

Expand All @@ -490,19 +481,15 @@ def before_app_first_request(
"""Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application.
"""
self.record_once(
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
)
self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
return f

def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
"""Like :meth:`Flask.after_request` but for a blueprint. Such a function
is executed after each request, even if outside of the blueprint.
"""
self.record_once(
lambda s: s.app.after_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
)
return f

Expand Down Expand Up @@ -553,14 +540,3 @@ def app_url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable:
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
)
return f

def ensure_sync(self, f: t.Callable) -> t.Callable:
"""Ensure the function is synchronous.

Override if you would like custom async to sync behaviour in
this blueprint. Otherwise the app's
:meth:`~flask.Flask.ensure_sync` is used.

.. versionadded:: 2.0
"""
return f
50 changes: 0 additions & 50 deletions src/flask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
import warnings
from datetime import timedelta
from functools import update_wrapper
from functools import wraps
from threading import RLock

import werkzeug.utils
from werkzeug.exceptions import NotFound
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError
from werkzeug.urls import url_quote

Expand Down Expand Up @@ -801,51 +799,3 @@ def is_ip(value: str) -> bool:
return True

return False


def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
try:
from asgiref.sync import async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)

# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python & Greenlet versions."
)

@wraps(func)
def outer(*args: t.Any, **kwargs: t.Any) -> t.Any:
"""This function grabs the current context for the inner function.

This is similar to the copy_current_xxx_context functions in the
ctx module, except it has an async inner.
"""
ctx = None

if _request_ctx_stack.top is not None:
ctx = _request_ctx_stack.top.copy()

@wraps(func)
async def inner(*a: t.Any, **k: t.Any) -> t.Any:
"""This restores the context before awaiting the func.

This is required as the function must be awaited within the
context. Only calling ``func`` (as per the
``copy_current_xxx_context`` functions) doesn't work as the
with block will close before the coroutine is awaited.
"""
if ctx is not None:
with ctx:
return await func(*a, **k)
else:
return await func(*a, **k)

return async_to_sync(inner)(*args, **kwargs)

outer._flask_sync_wrapper = True # type: ignore
return outer
Loading