diff --git a/docs/gb/gb0007.md b/docs/gb/gb0007.md
index 7f86900..72b0961 100644
--- a/docs/gb/gb0007.md
+++ b/docs/gb/gb0007.md
@@ -29,36 +29,94 @@ explanation
 Example code that causes the graph break is:
 ```python
-def fn():
-    return unittest.skip("test")
-compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
-compiled_fn()
+import torch
+from tqdm import tqdm
+
+
+def fn(x):
+    for i in tqdm(range(5)):
+        x += i
+    return x
+
+
+compiled_fn = torch.compile(fn, fullgraph=True)
+compiled_fn(torch.randn(3))
 ```
-The first workaround is to remove the skipped function call since the function unittest.skip is designed for test control flow and should not be part of your model's logic. The correct fix is to remove such calls from any code path that you intend to compile.
+
+The first workaround is to remove the skipped function:
 ```python
-def fn_fixed():
-    return torch.ones(5)  # The actual computation
-compiled_fn = torch.compile(fn_fixed, backend="eager", fullgraph=True)
-result = compiled_fn()
+def fn(x):
+    for i in range(5):
+        x += i
+    return x
+
+
+compiled_fn = torch.compile(fn, fullgraph=True)
+compiled_fn(torch.randn(3))
 ```
-The second workaround is to override Dynamo's default behavior using the @torch._dynamo.dont_skip_tracing decorator. NOTE: This is an advanced feature and may lead to further graph breaks if the function's internals are also untraceable so proceed with caution.
+You can use `torch.compiler.is_compiling()` if you only want to remove the function when `torch.compile` is active:
+
 ```python
-def fn():
-    return unittest.skip("This is a test skip reason")
-
-with torch._dynamo.config.patch(
-    skipfile_rules=(
-        (unittest.case, torch._dynamo.trace_rules.DONT_SKIP),
-    )
-):
-    try:
-        # This will now trace INTO unittest.skip, which itself will raise unittest.case.SkipTest. Dynamo will then graph break.
-        compiled_fn = torch.compile(fn, backend="eager")
-        compiled_fn()
-    except unittest.case.SkipTest:
-        print("\nSuccess: Traced into the skipped function and caught the expected SkipTest exception.")
+def fn(x):
+    it = range(5)
+    if not torch.compiler.is_compiling():
+        it = tqdm(it)
+    for i in it:
+        x += i
+    return x
+
+
+compiled_fn = torch.compile(fn, fullgraph=True)
+compiled_fn(torch.randn(3))
+```
+
+The second workaround is to not compile the skipped function:
+```python
+@torch.compile(fullgraph=True)
+def inner(x, i):
+    x += i
+
+def fn(x):
+    for i in tqdm(range(5)):
+        inner(x, i)
+    return x
+
+fn(torch.randn(3))
 ```
+
+
+The third workaround is to override Dynamo's default skipping behavior using the `@torch._dynamo.dont_skip_tracing` decorator. NOTE: This is an advanced feature and may lead to further graph breaks if the function's internals are also untraceable, so proceed with caution.
+```python
+@torch._dynamo.dont_skip_tracing
+def fn(x):
+    for i in tqdm(range(5)):
+        x += i
+    return x
+
+
+compiled_fn = torch.compile(fn, fullgraph=True)
+# Another graph break because we attempted to trace into `tqdm.__new__`
+compiled_fn(torch.randn(3))
+```
+
+If you are attempting to call a logging function (e.g. `_warnings.warn`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`:
+```python
+import warnings
+
+torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)
+
+def fn(x):
+    warnings.warn("warning")
+    for i in range(5):
+        x += i
+    return x
+
+
+compiled_fn = torch.compile(fn, fullgraph=True)
+compiled_fn(torch.randn(3))
+```