update convert_safetensors.py
This commit is contained in:
		
							parent
							
								
									f7227cd1c1
								
							
						
					
					
						commit
						053a08f5b7
					
				
							
								
								
									
										20
									
								
								backend-python/convert_safetensors.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										20
									
								
								backend-python/convert_safetensors.py
									
									
									
									
										vendored
									
									
								
							@ -25,7 +25,7 @@ def rename_key(rename, name):
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
 | 
			
		||||
def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
 | 
			
		||||
    loaded = torch.load(pt_filename, map_location="cpu")
 | 
			
		||||
    if "state_dict" in loaded:
 | 
			
		||||
        loaded = loaded["state_dict"]
 | 
			
		||||
@ -34,12 +34,14 @@ def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename=
 | 
			
		||||
    # for k, v in loaded.items():
 | 
			
		||||
    #     print(f'{k}\t{v.shape}\t{v.dtype}')
 | 
			
		||||
 | 
			
		||||
    loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
 | 
			
		||||
    # For tensors to be contiguous
 | 
			
		||||
    for k, v in loaded.items():
 | 
			
		||||
        for transpose_name in transpose_names:
 | 
			
		||||
            if transpose_name in k:
 | 
			
		||||
                loaded[k] = v.transpose(0, 1)
 | 
			
		||||
    loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
 | 
			
		||||
 | 
			
		||||
    loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}
 | 
			
		||||
 | 
			
		||||
    for k, v in loaded.items():
 | 
			
		||||
        print(f"{k}\t{v.shape}\t{v.dtype}")
 | 
			
		||||
@ -60,8 +62,18 @@ if __name__ == "__main__":
 | 
			
		||||
        convert_file(
 | 
			
		||||
            args.input,
 | 
			
		||||
            args.output,
 | 
			
		||||
            ["lora_A"],
 | 
			
		||||
            {"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"},
 | 
			
		||||
            rename={
 | 
			
		||||
                "time_faaaa": "time_first",
 | 
			
		||||
                "time_maa": "time_mix",
 | 
			
		||||
                "lora_A": "lora.0",
 | 
			
		||||
                "lora_B": "lora.1",
 | 
			
		||||
            },
 | 
			
		||||
            transpose_names=[
 | 
			
		||||
                "time_mix_w1",
 | 
			
		||||
                "time_mix_w2",
 | 
			
		||||
                "time_decay_w1",
 | 
			
		||||
                "time_decay_w2",
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        print(f"Saved to {args.output}")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user