Support training for EliGen and NexusGen

This commit is contained in:
mi804
2025-07-29 13:28:42 +08:00
parent 2861ec4d9f
commit 8ef91b3672
14 changed files with 218 additions and 17 deletions

View File

@@ -14,7 +14,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
@staticmethod
def state_dict_converter():
    """Return the state-dict converter associated with this model.

    Used by the loading machinery to remap checkpoint keys before they
    are fed into the model.
    """
    converter = NexusGenAutoregressiveModelStateDictConverter()
    return converter
@@ -34,6 +34,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
return messages
def get_generation_msg(self, instruction):
    """Build a two-turn chat message list for image generation.

    The user turn carries the generation instruction wrapped in a fixed
    prompt template; the assistant turn is pre-filled with an ``<image>``
    placeholder so generation continues from the image token.
    """
    prompt = "Generate an image according to the following description: {}".format(instruction)
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": "Here is an image based on the description: <image>"}
    return [user_turn, assistant_turn]
@@ -80,9 +81,10 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
)
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
position_ids, _ = model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids, _ = model.get_rope_index(
inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
@@ -97,4 +99,3 @@ class NexusGenAutoregressiveModelStateDictConverter:
def from_civitai(self, state_dict):
    """Prefix every checkpoint key with ``model.`` so weights land on the
    wrapped ``self.model`` submodule. Values are passed through untouched.
    """
    renamed = {}
    for name, tensor in state_dict.items():
        renamed["model." + name] = tensor
    return renamed

View File

@@ -767,9 +767,10 @@ class FluxImageUnit_EntityControl(PipelineUnit):
if eligen_entity_prompts is None or eligen_entity_masks is None:
return inputs_shared, inputs_posi, inputs_nega
pipe.load_models_to_device(self.onload_model_names)
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"])
inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"])
inputs_posi.update(eligen_kwargs_posi)
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
inputs_nega.update(eligen_kwargs_nega)

View File

@@ -120,8 +120,13 @@ class ImageDataset(torch.utils.data.Dataset):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if isinstance(data[key], list):
print(f"Loading multiple files for key '{key}'.")
path = [os.path.join(self.base_path, p) for p in data[key]]
data[key] = [self.load_data(p) for p in path]
else:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None