| 
22 | 22 | import tempfile  | 
23 | 23 | import unittest  | 
24 | 24 | import warnings  | 
 | 25 | +from pathlib import Path  | 
25 | 26 | 
 
  | 
26 | 27 | import numpy as np  | 
27 | 28 | import pytest  | 
@@ -4995,6 +4996,27 @@ def test_custom_generate_requires_trust_remote_code(self):  | 
4995 | 4996 |         with self.assertRaises(ValueError):  | 
4996 | 4997 |             model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")  | 
4997 | 4998 | 
 
  | 
 | 4999 | +    def test_custom_generate_local_directory(self):  | 
 | 5000 | +        """Tests that custom_generate works with local directories containing importable relative modules"""  | 
 | 5001 | +        with tempfile.TemporaryDirectory() as tmp_dir:  | 
 | 5002 | +            custom_generate_dir = Path(tmp_dir) / "custom_generate"  | 
 | 5003 | +            custom_generate_dir.mkdir()  | 
 | 5004 | +            with open(custom_generate_dir / "generate.py", "w") as f:  | 
 | 5005 | +                f.write("from .helper import ret_success\ndef generate(*args, **kwargs):\n    return ret_success()\n")  | 
 | 5006 | +            with open(custom_generate_dir / "helper.py", "w") as f:  | 
 | 5007 | +                f.write('def ret_success():\n    return "success"\n')  | 
 | 5008 | +            model = AutoModelForCausalLM.from_pretrained(  | 
 | 5009 | +                "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"  | 
 | 5010 | +            )  | 
 | 5011 | +            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")  | 
 | 5012 | +            model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)  | 
 | 5013 | +            value = model.generate(  | 
 | 5014 | +                **model_inputs,  | 
 | 5015 | +                custom_generate=str(tmp_dir),  | 
 | 5016 | +                trust_remote_code=True,  | 
 | 5017 | +            )  | 
 | 5018 | +            assert value == "success"  | 
 | 5019 | + | 
4998 | 5020 | 
 
  | 
4999 | 5021 | @require_torch  | 
5000 | 5022 | class TokenHealingTestCase(unittest.TestCase):  | 
 | 
0 commit comments