Describe the bug
It seems that Hyper-FLUX.1-dev-8steps-lora has no effect on an fp8-quantized FLUX.1-dev: the generated image is identical whether or not the LoRA is loaded. My code is in the Reproduction section below. Has anyone managed to use Hyper-FLUX.1-dev-8steps-lora with FLUX.1-dev quantized to fp8?
Reproduction
Quantize the FLUX.1-dev transformer and T5 text encoder to fp8 with optimum.quanto, then load and fuse Hyper-FLUX.1-dev-8steps-lora (the snippet runs inside a wrapper class, hence the `self.*` attributes):

```python
import os

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
from safetensors.torch import load_file

# Load the transformer in bf16, quantize its weights to float8, and freeze them.
self.transformer = FluxTransformer2DModel.from_single_file(os.path.join(self.model_root, self.config["transformer_path"]), torch_dtype=torch.bfloat16).to(self.device)
quantize(self.transformer, weights=qfloat8)
freeze(self.transformer)

# Same treatment for the T5 text encoder.
self.text_encoder_2 = T5EncoderModel.from_pretrained(os.path.join(self.model_root, self.config["text_encoder_2_repo"]), torch_dtype=torch.bfloat16).to(self.device)
quantize(self.text_encoder_2, weights=qfloat8)
freeze(self.text_encoder_2)

# Build the pipeline without those two components, then attach the quantized ones.
self.pipe = FluxPipeline.from_pretrained(os.path.join(self.model_root, self.config["flux_repo"]), transformer=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to(self.device)
self.pipe.transformer = self.transformer
self.pipe.text_encoder_2 = self.text_encoder_2

# Load the Hyper-FLUX 8-step LoRA and fuse it into the already-quantized weights.
self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["8steps_lora"]), device=self.device), adapter_name="8steps")
self.pipe.fuse_lora(lora_scale=1.0)
```
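For what it's worth, here is a minimal sketch of one possible workaround, assuming the problem is that `fuse_lora` cannot modify weights that optimum.quanto has already quantized and frozen: load and fuse the LoRA while the weights are still bf16, and quantize only afterwards. The repo id, weight filename, and `lora_scale=0.125` below come from the Hyper-SD model card; they are illustrative and not verified against this exact setup.

```python
import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

# Load the full pipeline in bf16 first (no quantization yet).
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Fuse the LoRA while the transformer weights are still plain bf16 tensors.
pipe.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="8steps")
pipe.fuse_lora(lora_scale=0.125)  # scale suggested on the Hyper-SD model card
pipe.unload_lora_weights()  # drop adapter bookkeeping; the deltas stay fused in

# Only now quantize the fused weights to float8 and freeze them.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
pipe.to("cuda")

image = pipe("a photo of a cat", num_inference_steps=8, guidance_scale=3.5).images[0]
```

If the output then differs with and without the LoRA, that would point to the fuse-after-quantize ordering as the culprit.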
Logs
No response
System Info
Who can help?