mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
train
This commit is contained in:
@@ -1181,13 +1181,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = None
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
decoder_layer.__call__,
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
@@ -1196,7 +1201,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
use_cache,
|
use_cache,
|
||||||
cache_position,
|
cache_position,
|
||||||
position_embeddings,
|
position_embeddings,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
|
# layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
# decoder_layer.__call__,
|
||||||
|
# hidden_states,
|
||||||
|
# causal_mask,
|
||||||
|
# position_ids,
|
||||||
|
# past_key_values,
|
||||||
|
# output_attentions,
|
||||||
|
# use_cache,
|
||||||
|
# cache_position,
|
||||||
|
# position_embeddings,
|
||||||
|
# )
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
torchvision
|
torchvision
|
||||||
cupy-cuda12x
|
cupy-cuda12x
|
||||||
transformers==4.46.2
|
transformers==4.49.0
|
||||||
controlnet-aux==0.0.7
|
controlnet-aux==0.0.7
|
||||||
imageio
|
imageio
|
||||||
imageio[ffmpeg]
|
imageio[ffmpeg]
|
||||||
@@ -11,3 +11,4 @@ sentencepiece
|
|||||||
protobuf
|
protobuf
|
||||||
modelscope
|
modelscope
|
||||||
ftfy
|
ftfy
|
||||||
|
qwen_vl_utils
|
||||||
|
|||||||
32
test.py
32
test.py
@@ -102,31 +102,35 @@ model_manager.load_models([
|
|||||||
])
|
])
|
||||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
|
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
|
||||||
pipe.dit.load_state_dict(state_dict, strict=False)
|
# pipe.dit.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
|
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
|
||||||
adapter.load_state_dict(state_dict, strict=False)
|
# adapter.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
|
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
|
||||||
|
|
||||||
|
sd = {}
|
||||||
|
for i in range(1, 6):
|
||||||
|
print(i)
|
||||||
|
sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
|
||||||
|
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
|
||||||
|
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
|
||||||
|
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
|
||||||
|
for i in sd:
|
||||||
|
if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")):
|
||||||
|
print(i)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
instruction = "Generate an image according to the following description: a beautiful Asian girl"
|
instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
|
||||||
emb = qwenvl(instruction, images=None)
|
emb = qwenvl(instruction, images=None)
|
||||||
emb = adapter(emb)
|
emb = adapter(emb)
|
||||||
image = pipe("", image_emb=emb)
|
image = pipe("", image_emb=emb, height=512, width=512)
|
||||||
image.save("image_1.jpg")
|
image.save("image_1.jpg")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses."
|
instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
|
||||||
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
|
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
|
||||||
emb = adapter(emb)
|
emb = adapter(emb)
|
||||||
image = pipe("", image_emb=emb)
|
image = pipe("", image_emb=emb, height=512, width=512)
|
||||||
image.save("image_2.jpg")
|
image.save("image_2.jpg")
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
|
|
||||||
emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
|
|
||||||
emb = adapter(emb)
|
|
||||||
image = pipe("", image_emb=emb)
|
|
||||||
image.save("image_3.jpg")
|
|
||||||
|
|||||||
12
train.py
12
train.py
@@ -89,8 +89,6 @@ class SingleTaskDataset(torch.utils.data.Dataset):
|
|||||||
def load_image(self, image_path, skip_process=False):
|
def load_image(self, image_path, skip_process=False):
|
||||||
image_path = os.path.join(self.base_path, image_path)
|
image_path = os.path.join(self.base_path, image_path)
|
||||||
image = Image.open(image_path).convert("RGB")
|
image = Image.open(image_path).convert("RGB")
|
||||||
if skip_process:
|
|
||||||
return image
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
scale = max(self.width / width, self.height / height)
|
scale = max(self.width / width, self.height / height)
|
||||||
image = torchvision.transforms.functional.resize(
|
image = torchvision.transforms.functional.resize(
|
||||||
@@ -98,6 +96,8 @@ class SingleTaskDataset(torch.utils.data.Dataset):
|
|||||||
(round(height*scale), round(width*scale)),
|
(round(height*scale), round(width*scale)),
|
||||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
)
|
)
|
||||||
|
if skip_process:
|
||||||
|
return image
|
||||||
image = self.image_process(image)
|
image = self.image_process(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -254,6 +254,10 @@ class UnifiedModel(pl.LightningModule):
|
|||||||
self.pipe.vae_decoder.requires_grad_(False)
|
self.pipe.vae_decoder.requires_grad_(False)
|
||||||
self.pipe.vae_encoder.requires_grad_(False)
|
self.pipe.vae_encoder.requires_grad_(False)
|
||||||
self.pipe.text_encoder_1.requires_grad_(False)
|
self.pipe.text_encoder_1.requires_grad_(False)
|
||||||
|
self.pipe.train()
|
||||||
|
self.adapter.train()
|
||||||
|
self.qwenvl.train()
|
||||||
|
# self.qwenvl.model.model.gradient_checkpointing = True
|
||||||
|
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
@@ -289,7 +293,7 @@ class UnifiedModel(pl.LightningModule):
|
|||||||
self.pipe.denoising_model(),
|
self.pipe.denoising_model(),
|
||||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||||
image_emb=emb,
|
image_emb=emb,
|
||||||
use_gradient_checkpointing=True
|
use_gradient_checkpointing=False
|
||||||
)
|
)
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||||
@@ -331,7 +335,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--steps_per_epoch",
|
"--steps_per_epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=1000,
|
||||||
help="steps_per_epoch",
|
help="steps_per_epoch",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user