Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,7 @@ def save(self, path, save_program=False):
with open(path, "wb") as f:
cloudpickle.dump(state, f)
else:
raise ValueError(
f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}"
)
raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

def load(self, path):
"""Load the saved module. You may also want to check out dspy.load, if you want to
Expand Down
5 changes: 3 additions & 2 deletions dspy/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@


def get_dependency_versions():
cloudpickle_version = '.'.join(cloudpickle.__version__.split('.')[:2])
return {
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
"python": f"{sys.version_info.major}.{sys.version_info.minor}",
"dspy": importlib_metadata.version("dspy"),
"cloudpickle": cloudpickle.__version__,
"cloudpickle": cloudpickle_version,
}


Expand Down
55 changes: 55 additions & 0 deletions tests/primitives/test_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dspy
import threading
from dspy.utils.dummies import DummyLM
import logging
from unittest.mock import patch


def test_deepcopy_basic():
Expand Down Expand Up @@ -106,3 +108,56 @@ def dummy_metric(example, pred, trace=None):

assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature)
assert new_cot.predict.demos == compiled_cot.predict.demos


def test_load_with_version_mismatch(tmp_path):
from dspy.primitives.module import logger

# Mock versions during save
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}

# Mock versions during load
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}

predict = dspy.Predict("question->answer")

# Create a custom handler to capture log messages
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.messages = []

def emit(self, record):
self.messages.append(record.getMessage())

# Add handler and set level
handler = ListHandler()
original_level = logger.level
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

try:
save_path = tmp_path / "program.pkl"
# Mock version during save
with patch("dspy.primitives.module.get_dependency_versions", return_value=save_versions):
predict.save(save_path)

# Mock version during load
with patch("dspy.primitives.module.get_dependency_versions", return_value=load_versions):
loaded_predict = dspy.Predict("question->answer")
loaded_predict.load(save_path)

# Assert warnings were logged, and one warning for each mismatched dependency.
assert len(handler.messages) == 3

for msg in handler.messages:
assert "There is a mismatch of" in msg

# Verify the model still loads correctly despite version mismatches
assert isinstance(loaded_predict, dspy.Predict)
assert str(predict.signature) == str(loaded_predict.signature)

finally:
# Clean up: restore original level and remove handler
logger.setLevel(original_level)
logger.removeHandler(handler)
55 changes: 55 additions & 0 deletions tests/utils/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import dspy
from dspy.utils import DummyLM
from unittest.mock import patch
import pytest
from dspy.utils.saving import get_dependency_versions
import logging


def test_save_predict(tmp_path):
Expand Down Expand Up @@ -74,3 +78,54 @@ def dummy_metric(example, pred, trace=None):
loaded_predict = dspy.load(tmp_path)
assert compiled_predict.demos == loaded_predict.demos
assert compiled_predict.signature == loaded_predict.signature


def test_load_with_version_mismatch(tmp_path):
from dspy.utils.saving import logger

# Mock versions during save
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}

# Mock versions during load
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}

predict = dspy.Predict("question->answer")

# Create a custom handler to capture log messages
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.messages = []

def emit(self, record):
self.messages.append(record.getMessage())

# Add handler and set level
handler = ListHandler()
original_level = logger.level
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

try:
# Mock version during save
with patch("dspy.utils.saving.get_dependency_versions", return_value=save_versions):
predict.save(tmp_path, save_program=True)

# Mock version during load
with patch("dspy.utils.saving.get_dependency_versions", return_value=load_versions):
loaded_predict = dspy.load(tmp_path)

# Assert warnings were logged, and one warning for each mismatched dependency.
assert len(handler.messages) == 3

for msg in handler.messages:
assert "There is a mismatch of" in msg

# Verify the model still loads correctly despite version mismatches
assert isinstance(loaded_predict, dspy.Predict)
assert predict.signature == loaded_predict.signature

finally:
# Clean up: restore original level and remove handler
logger.setLevel(original_level)
logger.removeHandler(handler)
Loading