support training for eligen and nexusgen

This commit is contained in:
mi804
2025-07-29 13:28:42 +08:00
parent 2861ec4d9f
commit 8ef91b3672
14 changed files with 218 additions and 17 deletions

View File

@@ -14,7 +14,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
@staticmethod
def state_dict_converter():
    """Return a new ``NexusGenAutoregressiveModelStateDictConverter``.

    A fresh converter instance is created on every call; nothing is cached.
    """
    return NexusGenAutoregressiveModelStateDictConverter()
@@ -34,6 +34,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
return messages
def get_generation_msg(self, instruction):
    """Build the two-message conversation that asks for an image.

    Args:
        instruction: Description of the desired image. It is interpolated
            into the user prompt with default string formatting.

    Returns:
        A list of two message dicts, each with ``"role"`` and ``"content"``
        keys: the user turn carrying the prompt, followed by an assistant
        turn whose fixed reply ends with the literal ``<image>`` tag.
    """
    prompt = f"Generate an image according to the following description: {instruction}"
    # NOTE(review): "<image>" presumably marks where the generated image is
    # placed in the assistant turn -- confirm against the caller.
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": "Here is an image based on the description: <image>"},
    ]
@@ -80,9 +81,10 @@ class NexusGenAutoregressiveModel(torch.nn.Module):
)
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
position_ids, _ = model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids, _ = model.get_rope_index(
inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
@@ -97,4 +99,3 @@ class NexusGenAutoregressiveModelStateDictConverter:
def from_civitai(self, state_dict):
    """Return a copy of ``state_dict`` with ``"model."`` prepended to each key.

    Args:
        state_dict: Mapping from parameter name to value. It is read only;
            the input mapping is not modified.

    Returns:
        A new dict holding the same values under the prefixed keys.
    """
    # NOTE(review): the prefix presumably matches the ``model`` attribute
    # that the wrapper module stores its network under -- confirm.
    return {"model." + key: value for key, value in state_dict.items()}