Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 82 additions & 24 deletions docs/gb/gb0007.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,94 @@ explanation
<!-- ADDITIONAL INFORMATION START - Add custom information below this line -->
Example code that causes the graph break is:
```python
def fn():
return unittest.skip("test")
compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
compiled_fn()
import torch
from tqdm import tqdm


def fn(x):
for i in tqdm(range(5)):
x += i
return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
```
The first workaround is to remove the skipped function call since the function unittest.skip is designed for test control flow and should not be part of your model's logic. The correct fix is to remove such calls from any code path that you intend to compile.

The first workaround is to remove the skipped function:
```python
def fn_fixed():
return torch.ones(5) # The actual computation
compiled_fn = torch.compile(fn_fixed, backend="eager", fullgraph=True)
result = compiled_fn()
def fn(x):
for i in range(5):
x += i
return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
```

The second workaround is to override Dynamo's default behavior using the @torch._dynamo.dont_skip_tracing decorator. NOTE: This is an advanced feature and may lead to further graph breaks if the function's internals are also untraceable so proceed with caution.
You can use `torch.compiler.is_compiling()` if you only want to remove the function when `torch.compile` is active:

```python
def fn():
return unittest.skip("This is a test skip reason")

with torch._dynamo.config.patch(
skipfile_rules=(
(unittest.case, torch._dynamo.trace_rules.DONT_SKIP),
)
):
try:
# This will now trace INTO unittest.skip, which itself will raise unittest.case.SkipTest. Dynamo will then graph break.
compiled_fn = torch.compile(fn, backend="eager")
compiled_fn()
except unittest.case.SkipTest:
print("\nSuccess: Traced into the skipped function and caught the expected SkipTest exception.")
def fn(x):
iter = range(5)
if not torch.compiler.is_compiling():
iter = tqdm(iter)
for i in iter:
x += i
return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
```

The second workaround is to not compile the skipped function:
```python
@torch.compile(fullgraph=True)
def inner(x, i):
x += i

def fn(x):
for i in tqdm(range(5)):
inner(x, i)
return x

fn(torch.randn(3))
```


The third workaround is to override Dynamo's default skipping behavior using the `@torch._dynamo.dont_skip_tracing` decorator. NOTE: This is an advanced feature and may lead to further graph breaks if the function's internals are also untraceable so proceed with caution.
```python
@torch._dynamo.dont_skip_tracing
def fn(x):
for i in tqdm(range(5)):
x += i
return x


compiled_fn = torch.compile(fn, fullgraph=True)
# Another graph break because we attempted to trace into `tqdm.__new__`
compiled_fn(torch.randn(3))
```

If you are attempting to call a logging function (e.g. `_warnings.warn`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`:
```python
import warnings

torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)

def fn(x):
warnings.warn("warning")
for i in range(5):
x += i
return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
```

<!-- ADDITIONAL INFORMATION END -->


Expand Down