using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading.Tasks;
+using System.Threading;
+using System;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
@@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel
        public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;


+        /// <summary>
+        /// Runs the scheduler steps.
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            // Get Scheduler
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
+                // Get timesteps
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+                // Create latent sample
+                var latents = await PrepareLatentsAsync(promptOptions, schedulerOptions, scheduler, timesteps);
+
+                // Get Model metadata
+                var metadata = await _unet.GetMetadataAsync();
+
+                // Get Time ids
+                var addTimeIds = GetAddTimeIds(schedulerOptions);
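+                // SDXL UNets take additional "time ids" (original size, crop offsets and target size)
+                // as conditioning; they are fed to the UNet alongside the pooled prompt embeddings below.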
+
+                // Get Guidance Scale Embedding
+                var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
+
+                // Loop through the timesteps
+                var step = 0;
+                foreach (var timestep in timesteps)
+                {
+                    step++;
+                    var stepTime = Stopwatch.GetTimestamp();
+                    cancellationToken.ThrowIfCancellationRequested();
+
+                    // Create input tensor.
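+                    // When classifier-free guidance is enabled, the latents and time ids are duplicated
+                    // (Repeat(2)) so the unconditional and conditional branches run in one batched UNet pass.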
+                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+                    var timestepTensor = CreateTimestepTensor(timestep);
+                    var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
+
+                    var outputChannels = performGuidance ? 2 : 1;
+                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+                    {
+                        inferenceParameters.AddInputTensor(inputTensor);
+                        inferenceParameters.AddInputTensor(timestepTensor);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
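+                        // LCM-distilled UNets can be exported with an extra guidance-scale (w-embedding)
+                        // conditioning input; when the model reports six inputs the embedding is supplied here.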
+                        if (inferenceParameters.InputCount == 6)
+                            inferenceParameters.AddInputTensor(guidanceEmbeddings);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+                        inferenceParameters.AddInputTensor(timeids);
+                        inferenceParameters.AddOutputBuffer(outputDimension);
+
+                        var results = await _unet.RunInferenceAsync(inferenceParameters);
+                        using (var result = results.First())
+                        {
+                            var noisePred = result.ToDenseTensor();
+
+                            // Perform guidance
+                            if (performGuidance)
+                                noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+                            // Scheduler Step
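+                            // Step(...) returns a scheduler step result; its Result property is the updated
+                            // latent sample (the call is synchronous, so this is not Task.Result).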
+                            latents = scheduler.Step(noisePred, timestep, latents).Result;
+                        }
+                    }
+
+                    ReportProgress(progressCallback, step, timesteps.Count, latents);
+                    _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
+                }
+
+                // Unload if required
+                if (_memoryMode == MemoryModeType.Minimum)
+                    await _unet.UnloadAsync();
+
+                // Decode Latents
+                return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
+            }
+        }
+
+
        /// <summary>
        /// Gets the scheduler.
        /// </summary>
@@ -42,5 +135,26 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
                _ => default
            };
        }
+
+
+        /// <summary>
+        /// Gets the guidance scale embedding.
+        /// </summary>
+        /// <param name="guidance">The guidance scale.</param>
+        /// <param name="embeddingDim">The embedding dimension.</param>
+        /// <returns></returns>
+        protected DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
+        {
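+            // Sinusoidal projection of the guidance scale, mirroring the w-embedding used for LCM
+            // guidance distillation: scale = (w - 1) * 1000, then sin/cos features over half the
+            // embedding dimension are concatenated into a [1, embeddingDim] tensor.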
+            var scale = (guidance - 1f) * 1000.0f;
+            var halfDim = embeddingDim / 2;
+            float log = MathF.Log(10000.0f) / (halfDim - 1);
+            var emb = Enumerable.Range(0, halfDim)
+                .Select(x => scale * MathF.Exp(-log * x))
+                .ToArray();
+            var embSin = emb.Select(MathF.Sin);
+            var embCos = emb.Select(MathF.Cos);
+            var guidanceEmbedding = embSin.Concat(embCos).ToArray();
+            return new DenseTensor<float>(guidanceEmbedding, new[] { 1, embeddingDim });
+        }
    }
}