Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ax/ax_types.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

include "ax/storage/future/thrift/generation_strategy.thrift"
77 changes: 77 additions & 0 deletions ax/storage/future/generation_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Self

from ax.storage.future.thrift.generation_strategy.thrift_types import (
GenerationNode as ThriftGenerationNode,
GenerationStrategy as ThriftGenerationStrategy,
)

from ax.storage.future.universal import ThriftSerializable, UniversalStruct
from pyre_extensions import assert_is_instance, override


@dataclass
class GenerationNode(ThriftSerializable):
name: str

@classmethod
@override
def thrift_type(cls) -> type[UniversalStruct]:
return ThriftGenerationNode

@override
def serialize(self) -> ThriftGenerationNode:
return ThriftGenerationNode(name=self.name)

@classmethod
@override
def deserialize(cls, struct: UniversalStruct) -> Self:
node_struct = assert_is_instance(struct, ThriftGenerationNode)

return cls(name=node_struct.name)


@dataclass
class GenerationStrategy(ThriftSerializable):
name: str
nodes: list[GenerationNode]
current_node_index: int = 0

@classmethod
@override
def thrift_type(cls) -> type[UniversalStruct]:
return ThriftGenerationStrategy

@override
def serialize(self) -> ThriftGenerationStrategy:
return ThriftGenerationStrategy(
name=self.name,
nodes=[node.serialize() for node in self.nodes],
current_node_index=self.current_node_index,
)

@classmethod
@override
def deserialize(cls, struct: UniversalStruct) -> Self:
gs_struct = assert_is_instance(struct, ThriftGenerationStrategy)

return cls(
name=gs_struct.name,
nodes=[GenerationNode.deserialize(node) for node in gs_struct.nodes],
current_node_index=gs_struct.current_node_index,
)


sobol_node = GenerationNode(name="sobol")
mbg_node = GenerationNode(name="modular_botorch_generator")

gs = GenerationStrategy(
name="gpei",
nodes=[sobol_node, mbg_node],
current_node_index=1,
)
18 changes: 18 additions & 0 deletions ax/storage/future/thrift/generation_strategy.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

namespace py3 ax.storage.future.thrift

struct GenerationNode {
1: string name;
}

struct GenerationStrategy {
1: string name;
2: list<GenerationNode> nodes;
3: i32 current_node_index;
}
63 changes: 63 additions & 0 deletions ax/storage/future/universal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Self


try:
from thrift.python.serializer import deserialize, Protocol, serialize
from thrift.python.types import Struct

UniversalStruct = Struct
JSONProtocol = Protocol.JSON
except ImportError: # Use Apache Thrift if thrift-python is not available.
from typing import Any

from thrift.protocol import TJSONProtocol
from thrift.transport import TTransport

def serialize(struct, protocol):
transport = TTransport.TMemoryBuffer()
proto = protocol.getProtocol(transport)
struct.write(proto)

return transport.getvalue()

def deserialize(klass, buf, protocol):
transport = TTransport.TMemoryBuffer(buf)
proto = protocol.getProtocol(transport)
obj = klass()
obj.read(proto)
return obj

UniversalStruct = Any
JSONProtocol = TJSONProtocol.TJSONProtocolFactory()


class ThriftSerializable(ABC):
@classmethod
@abstractmethod
def thrift_type(cls) -> type[UniversalStruct]: ...

@abstractmethod
def serialize(self) -> UniversalStruct: ...

@classmethod
@abstractmethod
def deserialize(cls, struct: UniversalStruct) -> Self: ...

def to_json(self) -> str:
struct = self.serialize()

return serialize(struct, protocol=JSONProtocol).decode("utf-8")

@classmethod
def from_json(cls, json_str: str) -> Self:
struct = deserialize(
cls.thrift_type(), json_str.encode("utf-8"), protocol=JSONProtocol
)

return cls.deserialize(struct=struct)
38 changes: 38 additions & 0 deletions build_thrift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import subprocess

from setuptools.command.build_py import build_py as _build_py


class build_py(_build_py):
def run(self):
subprocess.check_call(["thrift", "-r", "--gen", "py", "ax/ax_types.thrift"])

for root, _dirs, files in os.walk("gen-py"):
for filename in files:
if filename == "ttypes.py":
old_path = os.path.join(root, filename)
new_path = os.path.join(root, "thrift_types.py")
os.rename(old_path, new_path)
print(f"Renamed {old_path} to {new_path}")

if os.path.exists("gen-py"):
for item in os.listdir("gen-py"):
src = os.path.join("gen-py", item)
dst = os.path.join("ax", item)
if os.path.isdir(src):
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.copytree(src, dst)
print(f"Copied {src} to {dst}")
elif os.path.isfile(src):
shutil.copy2(src, dst)
print(f"Copied {src} to {dst}")

super().run()
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"pyre-extensions",
"sympy",
"markdown",
"thrift",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -86,6 +87,9 @@ find = {}
[tool.setuptools.package-data]
"*" = ["*.js", "*.css", "*.html"]

[tool.setuptools.cmdclass]
build_py = "build_thrift:build_py"

[tool.setuptools_scm]
write_to = "ax/version.py"
local_scheme = "node-and-date"
Expand Down
Loading