mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
Support training for EliGen and NexusGen
This commit is contained in:
@@ -14,7 +14,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
|
||||
|
||||
|
||||
|
||||
@staticmethod
def state_dict_converter():
    """Return the converter that remaps raw checkpoint keys for this model.

    Returns:
        NexusGenAutoregressiveModelStateDictConverter: a fresh converter
        instance; callers use it to rename state-dict entries before loading.
    """
    # A new converter per call — the converter itself is stateless.
    converter = NexusGenAutoregressiveModelStateDictConverter()
    return converter
|
||||
@@ -34,6 +34,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
return messages
|
||||
|
||||
def get_generation_msg(self, instruction):
    """Build the two-turn chat message list for text-to-image generation.

    Args:
        instruction: free-form text description of the desired image.

    Returns:
        list[dict]: a user turn carrying the wrapped instruction followed by
        an assistant turn containing the ``<image>`` placeholder token.
    """
    # Wrap the raw instruction in the fixed generation prompt template.
    prompt = "Generate an image according to the following description: {}".format(instruction)
    user_turn = {"role": "user", "content": prompt}
    # The assistant turn pre-commits to answering with an image placeholder,
    # which downstream processing replaces with image embeddings.
    assistant_turn = {"role": "assistant", "content": "Here is an image based on the description: <image>"}
    return [user_turn, assistant_turn]
|
||||
|
||||
@@ -80,9 +81,10 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
)
|
||||
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||
|
||||
position_ids, _ = model.get_rope_index(inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids, _ = model.get_rope_index(
|
||||
inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids = position_ids.contiguous()
|
||||
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||
@@ -97,4 +99,3 @@ class NexusGenAutoregressiveModelStateDictConverter:
|
||||
def from_civitai(self, state_dict):
    """Rename checkpoint keys by prefixing each with ``model.``.

    The wrapper class stores the underlying network under a ``model``
    attribute, so every raw parameter name must gain that prefix before
    ``load_state_dict`` can match it.

    Args:
        state_dict: mapping of raw parameter names to tensors.

    Returns:
        dict: a new mapping with every key prefixed by ``"model."``;
        the input mapping is left untouched.
    """
    renamed = {}
    for name, tensor in state_dict.items():
        renamed["model." + name] = tensor
    return renamed
|
||||
|
||||
@@ -767,9 +767,10 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"])
|
||||
inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
|
||||
@@ -120,8 +120,13 @@ class ImageDataset(torch.utils.data.Dataset):
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if isinstance(data[key], list):
|
||||
print(f"Loading multiple files for key '{key}'.")
|
||||
path = [os.path.join(self.base_path, p) for p in data[key]]
|
||||
data[key] = [self.load_data(p) for p in path]
|
||||
else:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if data[key] is None:
|
||||
warnings.warn(f"cannot load file {data[key]}.")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user