diff --git a/train_flux_reference.py b/train_flux_reference.py index 8d15102..3d626af 100644 --- a/train_flux_reference.py +++ b/train_flux_reference.py @@ -32,6 +32,7 @@ class LightningModel(LightningModelForT2ILoRA): self.pipe = FluxImagePipeline.from_model_manager(model_manager) self.pipe.reference_embedder = FluxReferenceEmbedder() + self.pipe.reference_embedder.init() if quantize is not None: self.pipe.dit.quantize()