
Commit 887b3df

mudler and kno10 authored
feat: cuda transformers (#1401)
* Use cuda in transformers if available

  tensorflow probably needs a different check.

  Signed-off-by: Erich Schubert <[email protected]>

* feat: expose CUDA at top level

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* tests: add to tests and create workflow for py extra backends

* doc: update note on how to use core images

---------

Signed-off-by: Erich Schubert <[email protected]>
Signed-off-by: Ettore Di Giacinto <[email protected]>
Co-authored-by: Erich Schubert <[email protected]>
1 parent 3822bd2 commit 887b3df

File tree

- .github/workflows/test-extra.yml
- Makefile
- api/backend/image.go
- api/backend/options.go
- api/config/config.go
- backend/python/transformers/test_transformers.py → backend/python/transformers/test_transformers_server.py
- backend/python/transformers/transformers_server.py
- docs/content/advanced/_index.en.md
- docs/content/getting_started/_index.en.md

9 files changed: +163 −11 lines changed

.github/workflows/test-extra.yml

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+---
+name: 'Tests extras backends'
+
+on:
+  pull_request:
+  push:
+    branches:
+      - master
+    tags:
+      - '*'
+
+concurrency:
+  group: ci-tests-extra-${{ github.head_ref || github.ref }}-${{ github.repository }}
+  cancel-in-progress: true
+
+jobs:
+  tests-linux:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Release space from worker
+        run: |
+          echo "Listing top largest packages"
+          pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+          head -n 30 <<< "${pkgs}"
+          echo
+          df -h
+          echo
+          sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+          sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+          sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+          sudo rm -rf /usr/local/lib/android
+          sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+          sudo rm -rf /usr/share/dotnet
+          sudo apt-get remove -y '^mono-.*' || true
+          sudo apt-get remove -y '^ghc-.*' || true
+          sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+          sudo apt-get remove -y 'php.*' || true
+          sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+          sudo apt-get remove -y '^google-.*' || true
+          sudo apt-get remove -y azure-cli || true
+          sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+          sudo apt-get remove -y '^gfortran-.*' || true
+          sudo apt-get autoremove -y
+          sudo apt-get clean
+          echo
+          echo "Listing top largest packages"
+          pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+          head -n 30 <<< "${pkgs}"
+          echo
+          sudo rm -rfv build || true
+          df -h
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential ffmpeg
+          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
+          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
+          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
+          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
+          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
+          sudo apt-get update && \
+          sudo apt-get install -y conda
+          sudo apt-get install -y ca-certificates cmake curl patch
+          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
+
+          sudo rm -rfv /usr/bin/conda || true
+
+      - name: Test
+        run: |
+          PATH=$PATH:/opt/conda/bin make test-extra

Makefile

Lines changed: 5 additions & 0 deletions
@@ -414,6 +414,11 @@ prepare-extra-conda-environments:
 	$(MAKE) -C backend/python/petals
 	$(MAKE) -C backend/python/exllama2
 
+prepare-test-extra:
+	$(MAKE) -C backend/python/transformers
+
+test-extra: prepare-test-extra
+	$(MAKE) -C backend/python/transformers test
 
 backend-assets/grpc:
 	mkdir -p backend-assets/grpc

api/backend/image.go

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 		model.WithContext(o.Context),
 		model.WithModel(c.Model),
 		model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
-			CUDA: c.Diffusers.CUDA,
+			CUDA: c.CUDA,
 			SchedulerType: c.Diffusers.SchedulerType,
 			PipelineType: c.Diffusers.PipelineType,
 			CFGScale: c.Diffusers.CFGScale,

api/backend/options.go

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
 		Seed: int32(c.Seed),
 		NBatch: int32(b),
 		NoMulMatQ: c.NoMulMatQ,
+		CUDA: c.CUDA, // diffusers, transformers
 		DraftModel: c.DraftModel,
 		AudioPath: c.VallE.AudioPath,
 		Quantization: c.Quantization,

api/config/config.go

Lines changed: 4 additions & 1 deletion
@@ -46,6 +46,10 @@ type Config struct {
 
 	// Vall-e-x
 	VallE VallE `yaml:"vall-e"`
+
+	// CUDA
+	// Explicitly enable CUDA or not (some backends might need it)
+	CUDA bool `yaml:"cuda"`
 }
 
 type VallE struct {
@@ -67,7 +71,6 @@ type GRPC struct {
 type Diffusers struct {
 	PipelineType string `yaml:"pipeline_type"`
 	SchedulerType string `yaml:"scheduler_type"`
-	CUDA bool `yaml:"cuda"`
 	EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
 	CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
 	IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
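Net effect of the two hunks above: `cuda` moves from the `diffusers` sub-section to a top-level model-config key, so non-diffusers backends such as transformers can honor it too. A minimal model YAML sketch of the new shape (the surrounding keys are illustrative placeholders; only the top-level `cuda` key comes from this commit):

```yaml
# Hypothetical model config sketch; only the top-level `cuda` key
# is introduced by this commit, the other values are examples.
name: my-embeddings-model
backend: transformers
parameters:
  model: bert-base-cased
# Explicitly enable CUDA (some backends might need it)
cuda: true
```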

backend/python/transformers/test_transformers.py renamed to backend/python/transformers/test_transformers_server.py

Lines changed: 6 additions & 3 deletions
@@ -31,7 +31,7 @@ def test_server_startup(self):
         """
         This method tests if the server starts up successfully
         """
-        time.sleep(2)
+        time.sleep(10)
         try:
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
@@ -48,11 +48,12 @@ def test_load_model(self):
         """
         This method tests if the model is loaded successfully
         """
+        time.sleep(10)
         try:
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
                 self.assertTrue(response.success)
                 self.assertEqual(response.message, "Model loaded successfully")
         except Exception as err:
@@ -65,11 +66,13 @@ def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
         """
+        time.sleep(10)
         try:
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
+                print(response.message)
                 self.assertTrue(response.success)
                 embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                 embedding_response = stub.Embedding(embedding_request)
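The tests above also double as a usage sketch for the backend's gRPC API. A minimal standalone client might look like the following, assuming a server already listening on localhost:50051 and the generated `backend_pb2`/`backend_pb2_grpc` stubs on the import path (the model name is simply the one the tests use):

```python
# Minimal client sketch mirroring the test flow above.
import grpc

import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)

    # Load an embeddings-capable model by name.
    result = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
    assert result.success, result.message

    # Ask the backend for sentence embeddings.
    reply = stub.Embedding(backend_pb2.PredictOptions(Embeddings="This is a test sentence."))
    print(len(reply.embeddings))
```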

backend/python/transformers/transformers_server.py

Lines changed: 38 additions & 5 deletions
@@ -14,14 +14,27 @@
 import backend_pb2_grpc
 
 import grpc
+import torch
 
-from transformers import AutoModel
+from transformers import AutoTokenizer, AutoModel
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
 
+
+def mean_pooling(model_output, attention_mask):
+    """
+    Mean pooling to get sentence embeddings. See:
+    https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1
+    """
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)  # Sum columns
+    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+    return sum_embeddings / sum_mask
+
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     """
@@ -56,9 +69,19 @@ def LoadModel(self, request, context):
         model_name = request.Model
         try:
             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)  # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            if request.CUDA:
+                try:
+                    # TODO: also tensorflow, make configurable
+                    import torch.cuda
+                    if torch.cuda.is_available():
+                        print("Loading model", model_name, "to CUDA.", file=sys.stderr)
+                        self.model = self.model.to("cuda")
+                except Exception as err:
+                    print("Not using CUDA:", err, file=sys.stderr)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
-
         # Implement your logic here for the LoadModel service
         # Replace this with your desired response
         return backend_pb2.Result(message="Model loaded successfully", success=True)
@@ -74,10 +97,20 @@ def Embedding(self, request, context):
         Returns:
             An EmbeddingResult object that contains the calculated embeddings.
         """
-        # Implement your logic here for the Embedding service
-        # Replace this with your desired response
+
+        # Tokenize input
+        max_length = 512
+        if request.Tokens != 0:
+            max_length = request.Tokens
+        encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
+
+        # Create word embeddings
+        model_output = self.model(**encoded_input)
+
+        # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
         print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
-        sentence_embeddings = self.model.encode(request.Embeddings)
+        print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
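For intuition, here is a self-contained sketch of the mean-pooling step introduced above, using dummy tensors (shapes invented for illustration): padded positions are masked out, so each sentence ends up with the mean of its real token embeddings.

```python
# Toy illustration of mean pooling over masked token embeddings.
import torch

batch, seq_len, hidden = 2, 4, 8
token_embeddings = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],   # sentence 1: 3 real tokens
                               [1, 1, 0, 0]])  # sentence 2: 2 real tokens

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = torch.sum(token_embeddings * mask, 1)   # sum over real tokens only
counts = torch.clamp(mask.sum(1), min=1e-9)      # avoid division by zero
sentence_embeddings = summed / counts            # one vector per sentence
print(sentence_embeddings.shape)                 # torch.Size([2, 8])
```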

docs/content/advanced/_index.en.md

Lines changed: 32 additions & 1 deletion
@@ -207,6 +207,9 @@ lora_adapter: "/path/to/lora/adapter"
 lora_base: "/path/to/lora/base"
 # Disable mulmatq (CUDA)
 no_mulmatq: true
+
+# Diffusers/transformers
+cuda: true
 ```
 
 ### Prompt templates
@@ -363,4 +366,32 @@ You can control the backends that are built by setting the `GRPC_BACKENDS` envir
 make GRPC_BACKENDS=backend-assets/grpc/llama-cpp build
 ```
 
-By default, all the backends are built.
+By default, all the backends are built.
+
+### Extra backends
+
+LocalAI can be extended with extra backends. The backends are implemented as `gRPC` services and can be written in any language. The container images that are built and published on [quay.io](https://quay.io/repository/go-skynet/local-ai?tab=tags) are split into core and extra images. By default, images ship with all the dependencies and backends supported by LocalAI (we call those `extra` images). The `-core` images instead bring only the strictly necessary dependencies to run LocalAI, with only a core set of backends.
+
+If you wish to build a custom container image with extra backends, you can start from a core image and build only the backends you are interested in. For instance, to use the diffusers backend:
+
+```Dockerfile
+FROM quay.io/go-skynet/local-ai:master-ffmpeg-core
+
+RUN PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers
+```
+
+Remember also to set the `EXTERNAL_GRPC_BACKENDS` environment variable (or the `--external-grpc-backends` CLI flag) to point to the backends you are using (`EXTERNAL_GRPC_BACKENDS="backend_name:/path/to/backend"`), for example with diffusers:
+
+```Dockerfile
+FROM quay.io/go-skynet/local-ai:master-ffmpeg-core
+
+RUN PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers
+
+ENV EXTERNAL_GRPC_BACKENDS="diffusers:/build/backend/python/diffusers/run.sh"
+```
+
+{{% notice note %}}
+
+You can specify remote external backends or paths to local files. The syntax is `backend-name:/path/to/backend` or `backend-name:host:port`.
+
+{{% /notice %}}

docs/content/getting_started/_index.en.md

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ You can control LocalAI with command line arguments, to specify a binding addres
 | --watchdog-busy-timeout value | $WATCHDOG_BUSY_TIMEOUT | 5m | Watchdog timeout. This will restart the backend if it crashes. |
 | --watchdog-idle-timeout value | $WATCHDOG_IDLE_TIMEOUT | 15m | Watchdog idle timeout. This will restart the backend if it crashes. |
 | --preload-backend-only | $PRELOAD_BACKEND_ONLY | false | If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups. |
+| --external-grpc-backends | $EXTERNAL_GRPC_BACKENDS | none | Comma separated list of external gRPC backends to use. Format: `name:host:port` or `name:/path/to/file` |
 
 ### Container images
 