Skip to content

Commit 55472bb

Browse files
committed
feat: add tiny dream stable diffusion support
Signed-off-by: Gianluca Boiano <[email protected]>
1 parent 6d187af commit 55472bb

File tree

15 files changed

+214
-36
lines changed

15 files changed

+214
-36
lines changed

.github/workflows/test.yml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ jobs:
1818
runs-on: ubuntu-latest
1919
strategy:
2020
matrix:
21+
gcc-version: ['13']
2122
go-version: ['1.21.x']
23+
env:
24+
CC: gcc-${{ matrix.gcc-version }}
25+
CXX: g++-${{ matrix.gcc-version }}
2226
steps:
2327
- name: Release space from worker
2428
run: |
@@ -54,7 +58,7 @@ jobs:
5458
df -h
5559
- name: Clone
5660
uses: actions/checkout@v4
57-
with:
61+
with:
5862
submodules: true
5963
- name: Setup Go ${{ matrix.go-version }}
6064
uses: actions/setup-go@v4
@@ -74,9 +78,9 @@ jobs:
7478
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
7579
sudo apt-get update && \
7680
sudo apt-get install -y conda
77-
sudo apt-get install -y ca-certificates cmake curl patch
81+
sudo apt-get install -y ca-certificates cmake curl libgomp1 patch
7882
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
79-
83+
8084
sudo rm -rfv /usr/bin/conda || true
8185
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers
8286
@@ -93,7 +97,7 @@ jobs:
9397
../.. && sudo make -j12 install
9498
- name: Test
9599
run: |
96-
GO_TAGS="stablediffusion tts" make test
100+
GO_TAGS="stablediffusion tinydream tts" make test
97101
98102
tests-apple:
99103
runs-on: macOS-latest
@@ -103,7 +107,7 @@ jobs:
103107
steps:
104108
- name: Clone
105109
uses: actions/checkout@v4
106-
with:
110+
with:
107111
submodules: true
108112
- name: Setup Go ${{ matrix.go-version }}
109113
uses: actions/setup-go@v4
@@ -122,4 +126,4 @@ jobs:
122126
run: |
123127
export C_INCLUDE_PATH=/usr/local/include
124128
export CPLUS_INCLUDE_PATH=/usr/local/include
125-
CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test
129+
CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ ARG TARGETVARIANT
1414
ENV BUILD_TYPE=${BUILD_TYPE}
1515
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh"
1616
ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
17-
ARG GO_TAGS="stablediffusion tts"
17+
ARG GO_TAGS="stablediffusion tinydream tts"
1818

1919
RUN apt-get update && \
2020
apt-get install -y ca-certificates curl patch pip cmake && apt-get clean

Makefile

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ PIPER_VERSION?=7fe05263b4ca3ffa93a53e2737643a6a6afb9a7b
3333
# stablediffusion version
3434
STABLEDIFFUSION_VERSION?=902db5f066fd137697e3b69d0fa10d4782bd2c2f
3535

36+
# tinydream version
37+
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
38+
3639
export BUILD_TYPE?=
3740
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
3841
export CMAKE_ARGS?=
@@ -122,6 +125,11 @@ ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion)
122125
OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion
123126
endif
124127

128+
ifeq ($(findstring tinydream,$(GO_TAGS)),tinydream)
129+
# OPTIONAL_TARGETS+=go-tiny-dream/libtinydream.a
130+
OPTIONAL_GRPC+=backend-assets/grpc/tinydream
131+
endif
132+
125133
ifeq ($(findstring tts,$(GO_TAGS)),tts)
126134
# OPTIONAL_TARGETS+=go-piper/libpiper_binding.a
127135
# OPTIONAL_TARGETS+=backend-assets/espeak-ng-data
@@ -165,6 +173,14 @@ sources/go-stable-diffusion:
165173
sources/go-stable-diffusion/libstablediffusion.a:
166174
$(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
167175

176+
## tiny-dream
177+
sources/go-tiny-dream:
178+
git clone --recurse-submodules https://github.com/M0Rf30/go-tiny-dream sources/go-tiny-dream
179+
cd sources/go-tiny-dream && git checkout -b build $(TINYDREAM_VERSION) && git submodule update --init --recursive --depth 1
180+
181+
sources/go-tiny-dream/libtinydream.a:
182+
$(MAKE) -C sources/go-tiny-dream libtinydream.a
183+
168184
## RWKV
169185
sources/go-rwkv:
170186
git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv
@@ -225,7 +241,7 @@ sources/go-piper/libpiper_binding.a: sources/go-piper
225241
backend/cpp/llama/llama.cpp:
226242
$(MAKE) -C backend/cpp/llama llama.cpp
227243

228-
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/go-ggml-transformers sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion
244+
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/go-ggml-transformers sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
229245
touch $@
230246

231247
replace:
@@ -235,6 +251,7 @@ replace:
235251
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/sources/whisper.cpp
236252
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/sources/go-bert
237253
$(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/sources/go-stable-diffusion
254+
$(GOCMD) mod edit -replace github.com/M0Rf30/go-tiny-dream=$(shell pwd)/sources/go-tiny-dream
238255
$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/sources/go-piper
239256

240257
prepare-sources: get-sources replace
@@ -253,6 +270,7 @@ rebuild: ## Rebuilds the project
253270
$(MAKE) -C sources/go-stable-diffusion clean
254271
$(MAKE) -C sources/go-bert clean
255272
$(MAKE) -C sources/go-piper clean
273+
$(MAKE) -C sources/go-tiny-dream clean
256274
$(MAKE) build
257275

258276
prepare: prepare-sources $(OPTIONAL_TARGETS)
@@ -316,6 +334,7 @@ test: prepare test-models/testmodel grpcs
316334
$(MAKE) test-llama-gguf
317335
$(MAKE) test-tts
318336
$(MAKE) test-stablediffusion
337+
$(MAKE) test-tinydream
319338

320339
prepare-e2e:
321340
mkdir -p $(TEST_DIR)
@@ -357,6 +376,10 @@ test-stablediffusion: prepare-test
357376
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
358377
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
359378

379+
test-tinydream: prepare-test
380+
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
381+
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tinydream" --flake-attempts 1 -v -r ./api ./pkg
382+
360383
test-container:
361384
docker build --target requirements -t local-ai-test-container .
362385
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
@@ -501,9 +524,13 @@ backend-assets/grpc/stablediffusion: backend-assets/grpc
501524
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
502525
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
503526
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/sources/go-stable-diffusion/ \
504-
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/; \
527+
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
505528
fi
506529

530+
backend-assets/grpc/tinydream: backend-assets/grpc sources/go-tiny-dream/libtinydream.a
531+
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-tiny-dream \
532+
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
533+
507534
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data sources/go-piper/libpiper_binding.a
508535
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/sources/go-piper \
509536
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/

api/api_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,43 @@ var _ = Describe("API test", func() {
548548
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
549549
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
550550
})
551+
It("installs and is capable to generate images", Label("tinydream"), func() {
552+
if runtime.GOOS != "linux" {
553+
Skip("test supported only on linux")
554+
}
555+
556+
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
557+
ID: "model-gallery@tinydream",
558+
Overrides: map[string]interface{}{
559+
"parameters": map[string]interface{}{"model": "tinydream_assets"},
560+
},
561+
})
562+
563+
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
564+
565+
uuid := response["uuid"].(string)
566+
567+
Eventually(func() bool {
568+
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
569+
fmt.Println(response)
570+
return response["processed"].(bool)
571+
}, "360s", "10s").Should(Equal(true))
572+
573+
resp, err := http.Post(
574+
"http://127.0.0.1:9090/v1/images/generations",
575+
"application/json",
576+
bytes.NewBuffer([]byte(`{
577+
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
578+
"seed":9000,
579+
"size": "256x256"}`)))
580+
// The response should contain an URL
581+
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
582+
dat, err := io.ReadAll(resp.Body)
583+
Expect(err).ToNot(HaveOccurred(), string(dat))
584+
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
585+
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
586+
587+
})
551588
It("installs and is capable to generate images", Label("stablediffusion"), func() {
552589
if runtime.GOOS != "linux" {
553590
Skip("test supported only on linux")

api/openai/image.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
4545
}
4646

4747
if m == "" {
48-
m = model.StableDiffusionBackend
48+
m = model.TinyDreamBackend
4949
}
5050
log.Debug().Msgf("Loading model: %+v", m)
5151

@@ -82,8 +82,13 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
8282
log.Debug().Msgf("Parameter Config: %+v", config)
8383

8484
// XXX: Only stablediffusion is supported for now
85-
if config.Backend == "" {
85+
switch config.Backend {
86+
case "stablediffusion":
8687
config.Backend = model.StableDiffusionBackend
88+
case "tinydream":
89+
config.Backend = model.TinyDreamBackend
90+
default:
91+
config.Backend = model.TinyDreamBackend
8792
}
8893

8994
sizeParts := strings.Split(input.Size, "x")

backend/go/image/main.go renamed to backend/go/image/stablediffusion/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ var (
1515
func main() {
1616
flag.Parse()
1717

18-
if err := grpc.StartServer(*addr, &StableDiffusion{}); err != nil {
18+
if err := grpc.StartServer(*addr, &Image{}); err != nil {
1919
panic(err)
2020
}
2121
}

backend/go/image/stablediffusion.go renamed to backend/go/image/stablediffusion/stablediffusion.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@ import (
88
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
99
)
1010

11-
type StableDiffusion struct {
11+
type Image struct {
1212
base.SingleThread
1313
stablediffusion *stablediffusion.StableDiffusion
1414
}
1515

16-
func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error {
16+
func (image *Image) Load(opts *pb.ModelOptions) error {
1717
var err error
1818
// Note: the Model here is a path to a directory containing the model files
19-
sd.stablediffusion, err = stablediffusion.New(opts.ModelFile)
19+
image.stablediffusion, err = stablediffusion.New(opts.ModelFile)
2020
return err
2121
}
2222

23-
func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error {
24-
return sd.stablediffusion.GenerateImage(
23+
func (image *Image) GenerateImage(opts *pb.GenerateImageRequest) error {
24+
return image.stablediffusion.GenerateImage(
2525
int(opts.Height),
2626
int(opts.Width),
2727
int(opts.Mode),

backend/go/image/tinydream/main.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package main
2+
3+
// Note: this is started internally by LocalAI and a server is allocated for each model
4+
5+
import (
6+
"flag"
7+
8+
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
9+
)
10+
11+
var (
12+
addr = flag.String("addr", "localhost:50051", "the address to connect to")
13+
)
14+
15+
func main() {
16+
flag.Parse()
17+
18+
if err := grpc.StartServer(*addr, &Image{}); err != nil {
19+
panic(err)
20+
}
21+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package main
2+
3+
// This is a wrapper to statisfy the GRPC service interface
4+
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
5+
import (
6+
"github.com/go-skynet/LocalAI/pkg/grpc/base"
7+
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
8+
"github.com/go-skynet/LocalAI/pkg/tinydream"
9+
)
10+
11+
type Image struct {
12+
base.SingleThread
13+
tinydream *tinydream.TinyDream
14+
}
15+
16+
func (image *Image) Load(opts *pb.ModelOptions) error {
17+
var err error
18+
// Note: the Model here is a path to a directory containing the model files
19+
image.tinydream, err = tinydream.New(opts.ModelFile)
20+
return err
21+
}
22+
23+
func (image *Image) GenerateImage(opts *pb.GenerateImageRequest) error {
24+
return image.tinydream.GenerateImage(
25+
int(opts.Height),
26+
int(opts.Width),
27+
int(opts.Step),
28+
int(opts.Seed),
29+
opts.PositivePrompt,
30+
opts.NegativePrompt,
31+
opts.Dst)
32+
}

go.mod

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ module github.com/go-skynet/LocalAI
33
go 1.21
44

55
require (
6+
github.com/M0Rf30/go-tiny-dream v0.0.0-20231120185742-d617ddbd38e8
67
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
78
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
89
github.com/go-audio/wav v1.1.0
9-
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa
1010
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
1111
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e
1212
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428
@@ -17,7 +17,6 @@ require (
1717
github.com/imdario/mergo v0.3.16
1818
github.com/json-iterator/go v1.1.12
1919
github.com/mholt/archiver/v3 v3.5.1
20-
github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef
2120
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
2221
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
2322
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530

0 commit comments

Comments
 (0)