- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.5k
[tests] tests for compilation + quantization (bnb) #11672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    Changes from 9 commits
      Commits
    
    
            Show all changes
          
          
            12 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      6fe2414
              
                start adding compilation tests for quantization.
              
              
                sayakpaul 29cca99
              
                fixes
              
              
                sayakpaul 0e2f5b4
              
                Merge branch 'main' into quant-compile-tests
              
              
                sayakpaul edf66b7
              
                make common utility.
              
              
                sayakpaul 11cfd6c
              
                modularize.
              
              
                sayakpaul 0e4f152
              
                add group offloading+compile
              
              
                sayakpaul d3010dd
              
                xfail
              
              
                sayakpaul af57070
              
                update
              
              
                sayakpaul 90dcbd2
              
                Merge branch 'main' into quant-compile-tests
              
              
                sayakpaul 6f5df29
              
                Update tests/quantization/test_torch_compile_utils.py
              
              
                sayakpaul d44a29d
              
                fixes
              
              
                sayakpaul fb8ec95
              
                Merge branch 'main' into quant-compile-tests
              
              
                sayakpaul File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2024 The HuggingFace Team Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a clone of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import gc | ||
| import unittest | ||
|  | ||
| import torch | ||
|  | ||
| from diffusers import DiffusionPipeline | ||
| from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device | ||
|  | ||
|  | ||
| @require_torch_gpu | ||
| @slow | ||
| class QuantCompileMiscTests(unittest.TestCase): | ||
|         
                  sayakpaul marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| quantization_config = None | ||
|  | ||
| def setUp(self): | ||
| super().setUp() | ||
| gc.collect() | ||
| backend_empty_cache(torch_device) | ||
| torch.compiler.reset() | ||
|  | ||
| def tearDown(self): | ||
| super().tearDown() | ||
| gc.collect() | ||
| backend_empty_cache(torch_device) | ||
| torch.compiler.reset() | ||
|  | ||
| def _init_pipeline(self, quantization_config, torch_dtype): | ||
| pipe = DiffusionPipeline.from_pretrained( | ||
| "stabilityai/stable-diffusion-3-medium-diffusers", | ||
| quantization_config=quantization_config, | ||
| torch_dtype=torch_dtype, | ||
| ) | ||
| return pipe | ||
|  | ||
| def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16): | ||
| pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda") | ||
| # import to ensure fullgraph True | ||
| pipe.transformer.compile(fullgraph=True) | ||
|  | ||
| for _ in range(2): | ||
| # small resolutions to ensure speedy execution. | ||
| pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) | ||
|  | ||
| def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16): | ||
| pipe = self._init_pipeline(quantization_config, torch_dtype) | ||
| pipe.enable_model_cpu_offload() | ||
| pipe.transformer.compile() | ||
|  | ||
| for _ in range(2): | ||
| # small resolutions to ensure speedy execution. | ||
| pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) | ||
|  | ||
| def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): | ||
| torch._dynamo.config.cache_size_limit = 10000 | ||
|  | ||
| pipe = self._init_pipeline(quantization_config, torch_dtype) | ||
| group_offload_kwargs = { | ||
| "onload_device": torch.device("cuda"), | ||
| "offload_device": torch.device("cpu"), | ||
| "offload_type": "leaf_level", | ||
| "use_stream": True, | ||
| "non_blocking": True, | ||
| } | ||
| pipe.transformer.enable_group_offload(**group_offload_kwargs) | ||
| pipe.transformer.compile() | ||
| for name, component in pipe.components.items(): | ||
| if name != "transformer" and isinstance(component, torch.nn.Module): | ||
| if torch.device(component.device).type == "cpu": | ||
| component.to("cuda") | ||
|  | ||
| for _ in range(2): | ||
| # small resolutions to ensure speedy execution. | ||
| pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.