zhaend commited on
Commit
e5ee7eb
·
verified ·
1 Parent(s): f9cf754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -146
app.py CHANGED
@@ -12,177 +12,101 @@ from src.pipeline import FluxPipeline
12
  from src.transformer_flux import FluxTransformer2DModel
13
  from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
 
15
- class ImageProcessor:
16
- def __init__(self, path):
17
- device = "cuda"
18
- self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)
19
- transformer = FluxTransformer2DModel.from_pretrained(path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device)
20
- self.pipe.transformer = transformer
21
- self.pipe.to(device)
22
-
23
- def clear_cache(self, transformer):
24
- for name, attn_processor in transformer.attn_processors.items():
25
- attn_processor.bank_kv.clear()
26
-
27
- @spaces.GPU()
28
- def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height=768, width=768, output_path=None, seed=42):
29
- image = self.pipe(
30
- prompt,
31
- height=int(height),
32
- width=int(width),
33
- guidance_scale=3.5,
34
- num_inference_steps=25,
35
- max_sequence_length=512,
36
- generator=torch.Generator("cpu").manual_seed(seed),
37
- subject_images=subject_imgs,
38
- spatial_images=spatial_imgs,
39
- cond_size=512,
40
- ).images[0]
41
- self.clear_cache(self.pipe.transformer)
42
- if output_path:
43
- image.save(output_path)
44
- return image
45
 
46
  # Initialize the image processor
47
  base_path = "black-forest-labs/FLUX.1-dev"
48
- lora_base_path = "EasyControl/models"
49
- style_lora_base_path = "Shakker-Labs"
50
- processor = ImageProcessor(base_path)
51
 
52
- # Define the Gradio interface
53
- def single_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora=None):
54
- # Set the control type
55
- if control_type == "subject":
56
- lora_path = os.path.join(lora_base_path, "subject.safetensors")
57
- elif control_type == "depth":
58
- lora_path = os.path.join(lora_base_path, "depth.safetensors")
59
- elif control_type == "seg":
60
- lora_path = os.path.join(lora_base_path, "seg.safetensors")
61
- elif control_type == "pose":
62
- lora_path = os.path.join(lora_base_path, "pose.safetensors")
63
- elif control_type == "inpainting":
64
- lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
65
- elif control_type == "hedsketch":
66
- lora_path = os.path.join(lora_base_path, "hedsketch.safetensors")
67
- elif control_type == "canny":
68
- lora_path = os.path.join(lora_base_path, "canny.safetensors")
69
- set_single_lora(processor.pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
70
-
71
- # Set the style LoRA
72
- if style_lora=="None":
73
- pass
74
- else:
75
- if style_lora == "Simple_Sketch":
76
- processor.pipe.unload_lora_weights()
77
- style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
78
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
79
- if style_lora == "Text_Poster":
80
- processor.pipe.unload_lora_weights()
81
- style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
82
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors")
83
- if style_lora == "Vector_Style":
84
- processor.pipe.unload_lora_weights()
85
- style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
86
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors")
87
 
88
- # Process the image
89
- subject_imgs = [subject_img] if subject_img else []
90
- spatial_imgs = [spatial_img] if spatial_img else []
91
- image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
92
- return image
93
 
94
  # Define the Gradio interface
95
- def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
96
- subject_path = os.path.join(lora_base_path, "subject.safetensors")
97
- inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
98
- set_multi_lora(processor.pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512)
 
 
99
 
100
  # Process the image
101
- subject_imgs = [subject_img] if subject_img else []
102
  spatial_imgs = [spatial_img] if spatial_img else []
103
- image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
104
- return image
105
 
106
- # Define the Gradio interface components
107
- control_types = ["subject", "depth", "pose", "inpainting", "hedsketch", "seg", "canny"]
108
- style_loras = ["Simple_Sketch", "Text_Poster", "Vector_Style", "None"]
109
-
110
- # Example data
111
- single_examples = [
112
- ["A SKS in the library", Image.open("./test_imgs/subject1.png"), None, 1024, 1024, 5, "subject", None],
113
- ["In a picturesque village, a narrow cobblestone street with rustic stone buildings, colorful blinds, and lush green spaces, a cartoon man drawn with simple lines and solid colors stands in the foreground, wearing a red shirt, beige work pants, and brown shoes, carrying a strap on his shoulder. The scene features warm and enticing colors, a pleasant fusion of nature and architecture, and the camera's perspective on the street clearly shows the charming and quaint environment., Integrating elements of reality and cartoon.", None, Image.open("./test_imgs/spatial1.png"), 1024, 1024, 1, "pose", "Vector_Style"],
114
- ]
115
- multi_examples = [
116
- ["A SKS on the car", Image.open("./test_imgs/subject2.png"), Image.open("./test_imgs/spatial2.png"), 1024, 1024, 7],
117
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
 
 
119
 
120
  # Create the Gradio Blocks interface
121
  with gr.Blocks() as demo:
122
- gr.Markdown("# Image Generation with EasyControl")
123
- gr.Markdown("Generate images using EasyControl with different control types and style LoRAs.")
 
 
 
 
124
 
125
- with gr.Tab("Single Condition Generation"):
126
  with gr.Row():
127
  with gr.Column():
128
- prompt = gr.Textbox(label="Prompt")
129
- subject_img = gr.Image(label="Subject Image", type="pil") # 上传图像文件
130
- spatial_img = gr.Image(label="Spatial Image", type="pil") # 上传图像文件
131
- height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768)
132
- width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768)
133
  seed = gr.Number(label="Seed", value=42)
 
134
  control_type = gr.Dropdown(choices=control_types, label="Control Type")
135
- style_lora = gr.Dropdown(choices=style_loras, label="Style LoRA")
136
  single_generate_btn = gr.Button("Generate Image")
137
  with gr.Column():
138
- single_output_image = gr.Image(label="Generated Image")
139
-
140
- # Add examples for Single Condition Generation
141
- gr.Examples(
142
- examples=single_examples,
143
- inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
144
- outputs=single_output_image,
145
- fn=single_condition_generate_image,
146
- cache_examples=False, # 缓存示例结果以加快加载速度
147
- label="Single Condition Examples"
148
- )
149
-
150
-
151
- with gr.Tab("Multi-Condition Generation"):
152
- with gr.Row():
153
- with gr.Column():
154
- multi_prompt = gr.Textbox(label="Prompt")
155
- multi_subject_img = gr.Image(label="Subject Image", type="pil") # 上传图像文件
156
- multi_spatial_img = gr.Image(label="Spatial Image", type="pil") # 上传图像文件
157
- multi_height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768)
158
- multi_width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768)
159
- multi_seed = gr.Number(label="Seed", value=42)
160
- multi_generate_btn = gr.Button("Generate Image")
161
- with gr.Column():
162
- multi_output_image = gr.Image(label="Generated Image")
163
-
164
- # Add examples for Multi-Condition Generation
165
- gr.Examples(
166
- examples=multi_examples,
167
- inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
168
- outputs=multi_output_image,
169
- fn=multi_condition_generate_image,
170
- cache_examples=False, # 缓存示例结果以加快加载速度
171
- label="Multi-Condition Examples"
172
- )
173
-
174
 
175
  # Link the buttons to the functions
176
  single_generate_btn.click(
177
- single_condition_generate_image,
178
- inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
179
- outputs=single_output_image
180
- )
181
- multi_generate_btn.click(
182
- multi_condition_generate_image,
183
- inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
184
- outputs=multi_output_image
185
  )
186
 
187
  # Launch the Gradio app
188
- demo.queue().launch()
 
12
  from src.transformer_flux import FluxTransformer2DModel
13
  from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
 
15
+ from huggingface_hub import hf_hub_download
16
+ hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/Ghibli.safetensors", local_dir="./checkpoints/models/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Initialize the image processor
19
  base_path = "black-forest-labs/FLUX.1-dev"
20
+ lora_base_path = "checkpoints/models/models"
 
 
21
 
22
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
23
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
24
+ pipe.transformer = transformer
25
+ pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ def clear_cache(transformer):
28
+ for name, attn_processor in transformer.attn_processors.items():
29
+ attn_processor.bank_kv.clear()
 
 
30
 
31
  # Define the Gradio interface
32
+ @spaces.GPU()
33
+ def dual_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, zero_steps):
34
+ # Set the control type
35
+ if control_type == "Ghibli":
36
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
37
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
38
 
39
  # Process the image
 
40
  spatial_imgs = [spatial_img] if spatial_img else []
 
 
41
 
42
+ # Image with use_zero_init=True
43
+ image_true = pipe(
44
+ prompt,
45
+ height=int(height),
46
+ width=int(width),
47
+ guidance_scale=3.5,
48
+ num_inference_steps=25,
49
+ max_sequence_length=512,
50
+ generator=torch.Generator("cpu").manual_seed(seed),
51
+ subject_images=[],
52
+ spatial_images=spatial_imgs,
53
+ cond_size=512,
54
+ use_zero_init=True,
55
+ zero_steps=int(zero_steps)
56
+ ).images[0]
57
+ clear_cache(pipe.transformer)
58
+
59
+ # Image with use_zero_init=False
60
+ image_false = pipe(
61
+ prompt,
62
+ height=int(height),
63
+ width=int(width),
64
+ guidance_scale=3.5,
65
+ num_inference_steps=25,
66
+ max_sequence_length=512,
67
+ generator=torch.Generator("cpu").manual_seed(seed),
68
+ subject_images=[],
69
+ spatial_images=spatial_imgs,
70
+ cond_size=512,
71
+ use_zero_init=False
72
+ ).images[0]
73
+ clear_cache(pipe.transformer)
74
+
75
+ return image_true, image_false
76
 
77
+ # Define the Gradio interface components
78
+ control_types = ["Ghibli"]
79
 
80
  # Create the Gradio Blocks interface
81
  with gr.Blocks() as demo:
82
+ gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
83
+ gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
84
+ gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
85
+
86
+ gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: Ghibli Studio style, Charming hand-drawn anime-style illustration")
87
+ gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
88
 
89
+ with gr.Tab("Ghibli Condition Generation"):
90
  with gr.Row():
91
  with gr.Column():
92
+ prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
93
+ spatial_img = gr.Image(label="Ghibli Image", type="pil")
94
+ height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
95
+ width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
 
96
  seed = gr.Number(label="Seed", value=42)
97
+ zero_steps = gr.Number(label="Zero Init Steps", value=1)
98
  control_type = gr.Dropdown(choices=control_types, label="Control Type")
 
99
  single_generate_btn = gr.Button("Generate Image")
100
  with gr.Column():
101
+ image_with_zero_init = gr.Image(label="Image CFG-Zero*")
102
+ image_without_zero_init = gr.Image(label="Image CFG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # Link the buttons to the functions
105
  single_generate_btn.click(
106
+ dual_condition_generate_image,
107
+ inputs=[prompt, spatial_img, height, width, seed, control_type, zero_steps],
108
+ outputs=[image_with_zero_init, image_without_zero_init]
 
 
 
 
 
109
  )
110
 
111
  # Launch the Gradio app
112
+ demo.queue().launch()