diff --git a/ax/ax_types.thrift b/ax/ax_types.thrift new file mode 100644 index 00000000000..20aa24ec0b6 --- /dev/null +++ b/ax/ax_types.thrift @@ -0,0 +1,8 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +include "ax/storage/future/thrift/generation_strategy.thrift" diff --git a/ax/storage/future/generation_strategy.py b/ax/storage/future/generation_strategy.py new file mode 100644 index 00000000000..d2034f043ce --- /dev/null +++ b/ax/storage/future/generation_strategy.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Self + +from ax.storage.future.thrift.generation_strategy.thrift_types import ( + GenerationNode as ThriftGenerationNode, + GenerationStrategy as ThriftGenerationStrategy, +) + +from ax.storage.future.universal import ThriftSerializable, UniversalStruct +from pyre_extensions import assert_is_instance, override + + +@dataclass +class GenerationNode(ThriftSerializable): + name: str + + @classmethod + @override + def thrift_type(cls) -> type[UniversalStruct]: + return ThriftGenerationNode + + @override + def serialize(self) -> ThriftGenerationNode: + return ThriftGenerationNode(name=self.name) + + @classmethod + @override + def deserialize(cls, struct: UniversalStruct) -> Self: + node_struct = assert_is_instance(struct, ThriftGenerationNode) + + return cls(name=node_struct.name) + + +@dataclass +class GenerationStrategy(ThriftSerializable): + name: str + nodes: list[GenerationNode] + current_node_index: int = 0 + + @classmethod + @override + def thrift_type(cls) -> type[UniversalStruct]: + return ThriftGenerationStrategy + + @override + def serialize(self) -> ThriftGenerationStrategy: + return ThriftGenerationStrategy( + name=self.name, + nodes=[node.serialize() for node in self.nodes], + current_node_index=self.current_node_index, + ) + + @classmethod + @override + def deserialize(cls, struct: UniversalStruct) -> Self: + gs_struct = assert_is_instance(struct, ThriftGenerationStrategy) + + return cls( + name=gs_struct.name, + nodes=[GenerationNode.deserialize(node) for node in gs_struct.nodes], + current_node_index=gs_struct.current_node_index, + ) + + +sobol_node = GenerationNode(name="sobol") +mbg_node = GenerationNode(name="modular_botorch_generator") + +gs = GenerationStrategy( + name="gpei", + nodes=[sobol_node, mbg_node], + current_node_index=1, +) diff --git a/ax/storage/future/thrift/generation_strategy.thrift b/ax/storage/future/thrift/generation_strategy.thrift new file mode 100644 index 00000000000..1617ab55627 --- /dev/null +++ b/ax/storage/future/thrift/generation_strategy.thrift @@ -0,0 +1,18 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +namespace py3 ax.storage.future.thrift + +struct GenerationNode { + 1: string name; +} + +struct GenerationStrategy { + 1: string name; + 2: list nodes; + 3: i32 current_node_index; +} diff --git a/ax/storage/future/universal.py b/ax/storage/future/universal.py new file mode 100644 index 00000000000..89c073bdded --- /dev/null +++ b/ax/storage/future/universal.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Self + + +try: + from thrift.python.serializer import deserialize, Protocol, serialize + from thrift.python.types import Struct + + UniversalStruct = Struct + JSONProtocol = Protocol.JSON +except ImportError: # Use Apache Thrift if thrift-python is not available. + from typing import Any + + from thrift.protocol import TJSONProtocol + from thrift.transport import TTransport + + def serialize(struct, protocol): + transport = TTransport.TMemoryBuffer() + proto = protocol.getProtocol(transport) + struct.write(proto) + + return transport.getvalue() + + def deserialize(klass, buf, protocol): + transport = TTransport.TMemoryBuffer(buf) + proto = protocol.getProtocol(transport) + obj = klass() + obj.read(proto) + return obj + + UniversalStruct = Any + JSONProtocol = TJSONProtocol.TJSONProtocolFactory() + + +class ThriftSerializable(ABC): + @classmethod + @abstractmethod + def thrift_type(cls) -> type[UniversalStruct]: ... + + @abstractmethod + def serialize(self) -> UniversalStruct: ... + + @classmethod + @abstractmethod + def deserialize(cls, struct: UniversalStruct) -> Self: ... + + def to_json(self) -> str: + struct = self.serialize() + + return serialize(struct, protocol=JSONProtocol).decode("utf-8") + + @classmethod + def from_json(cls, json_str: str) -> Self: + struct = deserialize( + cls.thrift_type(), json_str.encode("utf-8"), protocol=JSONProtocol + ) + + return cls.deserialize(struct=struct) diff --git a/build_thrift.py b/build_thrift.py new file mode 100644 index 00000000000..1ea43a345ac --- /dev/null +++ b/build_thrift.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import subprocess + +from setuptools.command.build_py import build_py as _build_py + + +class build_py(_build_py): + def run(self): + subprocess.check_call(["thrift", "-r", "--gen", "py", "ax/ax_types.thrift"]) + + for root, _dirs, files in os.walk("gen-py"): + for filename in files: + if filename == "ttypes.py": + old_path = os.path.join(root, filename) + new_path = os.path.join(root, "thrift_types.py") + os.rename(old_path, new_path) + print(f"Renamed {old_path} to {new_path}") + + if os.path.exists("gen-py"): + for item in os.listdir("gen-py"): + src = os.path.join("gen-py", item) + dst = os.path.join("ax", item) + if os.path.isdir(src): + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + print(f"Copied {src} to {dst}") + elif os.path.isfile(src): + shutil.copy2(src, dst) + print(f"Copied {src} to {dst}") + + super().run() diff --git a/pyproject.toml b/pyproject.toml index 6885300f09a..61058141b97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pyre-extensions", "sympy", "markdown", + "thrift", ] [project.optional-dependencies] @@ -86,6 +87,9 @@ find = {} [tool.setuptools.package-data] "*" = ["*.js", "*.css", "*.html"] +[tool.setuptools.cmdclass] +build_py = "build_thrift:build_py" + [tool.setuptools_scm] write_to = "ax/version.py" local_scheme = "node-and-date"