Skip to content

Commit f971d06

Browse files
committed
tests: add to tests and create workflow for py extra backends
1 parent 6d2001a commit f971d06

File tree

4 files changed

+114
-7
lines changed

4 files changed

+114
-7
lines changed

.github/workflows/test-extra.yml

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
---
2+
name: 'Tests extras backends'
3+
4+
on:
5+
pull_request:
6+
push:
7+
branches:
8+
- master
9+
tags:
10+
- '*'
11+
12+
concurrency:
13+
group: ci-tests-extra-${{ github.head_ref || github.ref }}-${{ github.repository }}
14+
cancel-in-progress: true
15+
16+
jobs:
17+
tests-linux:
18+
runs-on: ubuntu-latest
19+
steps:
20+
- name: Release space from worker
21+
run: |
22+
echo "Listing top largest packages"
23+
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
24+
head -n 30 <<< "${pkgs}"
25+
echo
26+
df -h
27+
echo
28+
sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
29+
sudo apt-get remove --auto-remove android-sdk-platform-tools || true
30+
sudo apt-get purge --auto-remove android-sdk-platform-tools || true
31+
sudo rm -rf /usr/local/lib/android
32+
sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
33+
sudo rm -rf /usr/share/dotnet
34+
sudo apt-get remove -y '^mono-.*' || true
35+
sudo apt-get remove -y '^ghc-.*' || true
36+
sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
37+
sudo apt-get remove -y 'php.*' || true
38+
sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
39+
sudo apt-get remove -y '^google-.*' || true
40+
sudo apt-get remove -y azure-cli || true
41+
sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
42+
sudo apt-get remove -y '^gfortran-.*' || true
43+
sudo apt-get autoremove -y
44+
sudo apt-get clean
45+
echo
46+
echo "Listing top largest packages"
47+
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
48+
head -n 30 <<< "${pkgs}"
49+
echo
50+
sudo rm -rfv build || true
51+
df -h
52+
- name: Clone
53+
uses: actions/checkout@v4
54+
with:
55+
submodules: true
56+
- name: Dependencies
57+
run: |
58+
sudo apt-get update
59+
sudo apt-get install build-essential ffmpeg
60+
curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
61+
sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
62+
gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
63+
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
64+
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
65+
sudo apt-get update && \
66+
sudo apt-get install -y conda
67+
sudo apt-get install -y ca-certificates cmake curl patch
68+
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
69+
70+
sudo rm -rfv /usr/bin/conda || true
71+
72+
- name: Test
73+
run: |
74+
PATH=$PATH:/opt/conda/bin make test-extra
75+

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,11 @@ prepare-extra-conda-environments:
414414
$(MAKE) -C backend/python/petals
415415
$(MAKE) -C backend/python/exllama2
416416

417+
prepare-test-extra:
418+
$(MAKE) -C backend/python/transformers
419+
420+
test-extra: prepare-test-extra
421+
$(MAKE) -C backend/python/transformers test
417422

418423
backend-assets/grpc:
419424
mkdir -p backend-assets/grpc

backend/python/transformers/test_transformers_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_server_startup(self):
3131
"""
3232
This method tests if the server starts up successfully
3333
"""
34-
time.sleep(2)
34+
time.sleep(10)
3535
try:
3636
self.setUp()
3737
with grpc.insecure_channel("localhost:50051") as channel:
@@ -48,11 +48,12 @@ def test_load_model(self):
4848
"""
4949
This method tests if the model is loaded successfully
5050
"""
51+
time.sleep(10)
5152
try:
5253
self.setUp()
5354
with grpc.insecure_channel("localhost:50051") as channel:
5455
stub = backend_pb2_grpc.BackendStub(channel)
55-
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
56+
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
5657
self.assertTrue(response.success)
5758
self.assertEqual(response.message, "Model loaded successfully")
5859
except Exception as err:
@@ -65,11 +66,12 @@ def test_embedding(self):
6566
"""
6667
This method tests if the embeddings are generated successfully
6768
"""
69+
time.sleep(10)
6870
try:
6971
self.setUp()
7072
with grpc.insecure_channel("localhost:50051") as channel:
7173
stub = backend_pb2_grpc.BackendStub(channel)
72-
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
74+
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
7375
print(response.message)
7476
self.assertTrue(response.success)
7577
embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")

backend/python/transformers/transformers_server.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,27 @@
1414
import backend_pb2_grpc
1515

1616
import grpc
17+
import torch
1718

18-
from transformers import AutoModel
19+
from transformers import AutoTokenizer, AutoModel
1920

2021
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
2122

2223
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
2324
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
2425

26+
27+
def mean_pooling(model_output, attention_mask):
28+
"""
29+
Mean pooling to get sentence embeddings. See:
30+
https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1
31+
"""
32+
token_embeddings = model_output[0]
33+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
34+
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) # Sum columns
35+
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
36+
return sum_embeddings / sum_mask
37+
2538
# Implement the BackendServicer class with the service methods
2639
class BackendServicer(backend_pb2_grpc.BackendServicer):
2740
"""
@@ -56,6 +69,8 @@ def LoadModel(self, request, context):
5669
model_name = request.Model
5770
try:
5871
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
72+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
73+
5974
if request.CUDA:
6075
try:
6176
# TODO: also tensorflow, make configurable
@@ -82,10 +97,20 @@ def Embedding(self, request, context):
8297
Returns:
8398
An EmbeddingResult object that contains the calculated embeddings.
8499
"""
85-
# Implement your logic here for the Embedding service
86-
# Replace this with your desired response
100+
101+
# Tokenize input
102+
max_length = 512
103+
if request.Tokens != 0:
104+
max_length = request.Tokens
105+
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
106+
107+
# Create word embeddings
108+
model_output = self.model(**encoded_input)
109+
110+
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
111+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
87112
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
88-
sentence_embeddings = self.model.encode(request.Embeddings)
113+
print("Embeddings:", sentence_embeddings, file=sys.stderr)
89114
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
90115

91116

0 commit comments

Comments
 (0)