Heekyung commited on
Commit
425c558
Β·
1 Parent(s): ee0cb7c

Delete visualizer_drag_gradio.py

Browse files
Files changed (1) hide show
  1. visualizer_drag_gradio.py +0 -940
visualizer_drag_gradio.py DELETED
@@ -1,940 +0,0 @@
1
- # https://huggingface.co/DragGan/DragGan-Models
2
- # https://arxiv.org/abs/2305.10973
3
- import os
4
- import os.path as osp
5
- from argparse import ArgumentParser
6
- from functools import partial
7
- from pathlib import Path
8
- import time
9
-
10
- import psutil
11
-
12
- import gradio as gr
13
- import numpy as np
14
- import torch
15
- from PIL import Image
16
-
17
- import dnnlib
18
- from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
19
- get_latest_points_pair, get_valid_mask,
20
- on_change_single_global_state)
21
- from viz.renderer import Renderer, add_watermark_np
22
-
23
-
24
- # download models from Hugging Face hub
25
- from huggingface_hub import snapshot_download
26
-
27
- model_dir = Path('./checkpoints')
28
- snapshot_download('DragGan/DragGan-Models',
29
- repo_type='model', local_dir=model_dir)
30
-
31
- parser = ArgumentParser()
32
- parser.add_argument('--share', action='store_true')
33
- parser.add_argument('--cache-dir', type=str, default='./checkpoints')
34
- args = parser.parse_args()
35
-
36
- cache_dir = args.cache_dir
37
-
38
- device = 'cuda'
39
- IS_SPACE = "DragGan/DragGan" in os.environ.get('SPACE_ID', '')
40
- TIMEOUT = 80
41
-
42
-
43
- def reverse_point_pairs(points):
44
- new_points = []
45
- for p in points:
46
- new_points.append([p[1], p[0]])
47
- return new_points
48
-
49
-
50
- def clear_state(global_state, target=None):
51
- """Clear target history state from global_state
52
- If target is not defined, points and mask will be both removed.
53
- 1. set global_state['points'] as empty dict
54
- 2. set global_state['mask'] as full-one mask.
55
- """
56
- if target is None:
57
- target = ['point', 'mask']
58
- if not isinstance(target, list):
59
- target = [target]
60
- if 'point' in target:
61
- global_state['points'] = dict()
62
- print('Clear Points State!')
63
- if 'mask' in target:
64
- image_raw = global_state["images"]["image_raw"]
65
- global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
66
- dtype=np.uint8)
67
- print('Clear mask State!')
68
-
69
- return global_state
70
-
71
-
72
- def init_images(global_state):
73
- """This function is called only ones with Gradio App is started.
74
- 0. pre-process global_state, unpack value from global_state of need
75
- 1. Re-init renderer
76
- 2. run `renderer._render_drag_impl` with `is_drag=False` to generate
77
- new image
78
- 3. Assign images to global state and re-generate mask
79
- """
80
-
81
- if isinstance(global_state, gr.State):
82
- state = global_state.value
83
- else:
84
- state = global_state
85
-
86
- state['renderer'].init_network(
87
- state['generator_params'], # res
88
- valid_checkpoints_dict[state['pretrained_weight']], # pkl
89
- state['params']['seed'], # w0_seed,
90
- None, # w_load
91
- state['params']['latent_space'] == 'w+', # w_plus
92
- 'const',
93
- state['params']['trunc_psi'], # trunc_psi,
94
- state['params']['trunc_cutoff'], # trunc_cutoff,
95
- None, # input_transform
96
- state['params']['lr'] # lr,
97
- )
98
-
99
- state['renderer']._render_drag_impl(state['generator_params'],
100
- is_drag=False,
101
- to_pil=True)
102
-
103
- init_image = state['generator_params'].image
104
- state['images']['image_orig'] = init_image
105
- state['images']['image_raw'] = init_image
106
- state['images']['image_show'] = Image.fromarray(
107
- add_watermark_np(np.array(init_image)))
108
- state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
109
- dtype=np.uint8)
110
- return global_state
111
-
112
-
113
- def update_image_draw(image, points, mask, show_mask, global_state=None):
114
-
115
- image_draw = draw_points_on_image(image, points)
116
- if show_mask and mask is not None and not (mask == 0).all() and not (
117
- mask == 1).all():
118
- image_draw = draw_mask_on_image(image_draw, mask)
119
-
120
- image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
121
- if global_state is not None:
122
- global_state['images']['image_show'] = image_draw
123
- return image_draw
124
-
125
-
126
- def preprocess_mask_info(global_state, image):
127
- """Function to handle mask information.
128
- 1. last_mask is None: Do not need to change mask, return mask
129
- 2. last_mask is not None:
130
- 2.1 global_state is remove_mask:
131
- 2.2 global_state is add_mask:
132
- """
133
- if isinstance(image, dict):
134
- last_mask = get_valid_mask(image['mask'])
135
- else:
136
- last_mask = None
137
- mask = global_state['mask']
138
-
139
- # mask in global state is a placeholder with all 1.
140
- if (mask == 1).all():
141
- mask = last_mask
142
-
143
- # last_mask = global_state['last_mask']
144
- editing_mode = global_state['editing_state']
145
-
146
- if last_mask is None:
147
- return global_state
148
-
149
- if editing_mode == 'remove_mask':
150
- updated_mask = np.clip(mask - last_mask, 0, 1)
151
- print(f'Last editing_state is {editing_mode}, do remove.')
152
- elif editing_mode == 'add_mask':
153
- updated_mask = np.clip(mask + last_mask, 0, 1)
154
- print(f'Last editing_state is {editing_mode}, do add.')
155
- else:
156
- updated_mask = mask
157
- print(f'Last editing_state is {editing_mode}, '
158
- 'do nothing to mask.')
159
-
160
- global_state['mask'] = updated_mask
161
- # global_state['last_mask'] = None # clear buffer
162
- return global_state
163
-
164
-
165
- def print_memory_usage():
166
- # Print system memory usage
167
- print(f"System memory usage: {psutil.virtual_memory().percent}%")
168
-
169
- # Print GPU memory usage
170
- if torch.cuda.is_available():
171
- device = torch.device("cuda")
172
- print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
173
- print(
174
- f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
175
- device_properties = torch.cuda.get_device_properties(device)
176
- available_memory = device_properties.total_memory - \
177
- torch.cuda.max_memory_allocated()
178
- print(f"Available GPU memory: {available_memory / 1e9} GB")
179
- else:
180
- print("No GPU available")
181
-
182
-
183
- # filter large models running on SPACES
184
- allowed_checkpoints = [] # all checkpoints
185
- if IS_SPACE:
186
- allowed_checkpoints = ["stylegan_human_v2_512.pkl",
187
- "stylegan2_dogs_1024_pytorch.pkl"]
188
-
189
- valid_checkpoints_dict = {
190
- f.name.split('.')[0]: str(f)
191
- for f in Path(cache_dir).glob('*.pkl')
192
- if f.name in allowed_checkpoints or not IS_SPACE
193
- }
194
- print('Valid checkpoint file:')
195
- print(valid_checkpoints_dict)
196
-
197
- init_pkl = 'stylegan_human_v2_512'
198
-
199
- with gr.Blocks() as app:
200
- gr.Markdown("""
201
- # DragGAN - Drag Your GAN
202
- ## Interactive Point-based Manipulation on the Generative Image Manifold
203
- ### Unofficial Gradio Demo
204
-
205
- **Due to high demand, only one model can be run at a time, or you can duplicate the space and run your own copy.**
206
-
207
- <a href="https://huggingface.co/spaces/radames/DragGan?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
208
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
209
-
210
- * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
211
- * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) Β© [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
212
- """)
213
-
214
- # renderer = Renderer()
215
- global_state = gr.State({
216
- "images": {
217
- # image_orig: the original image, change with seed/model is changed
218
- # image_raw: image with mask and points, change durning optimization
219
- # image_show: image showed on screen
220
- },
221
- "temporal_params": {
222
- # stop
223
- },
224
- 'mask':
225
- None, # mask for visualization, 1 for editing and 0 for unchange
226
- 'last_mask': None, # last edited mask
227
- 'show_mask': True, # add button
228
- "generator_params": dnnlib.EasyDict(),
229
- "params": {
230
- "seed": int(np.random.randint(0, 2**32 - 1)),
231
- "motion_lambda": 20,
232
- "r1_in_pixels": 3,
233
- "r2_in_pixels": 12,
234
- "magnitude_direction_in_pixels": 1.0,
235
- "latent_space": "w+",
236
- "trunc_psi": 0.7,
237
- "trunc_cutoff": None,
238
- "lr": 0.001,
239
- },
240
- "device": device,
241
- "draw_interval": 1,
242
- "renderer": Renderer(disable_timing=True),
243
- "points": {},
244
- "curr_point": None,
245
- "curr_type_point": "start",
246
- 'editing_state': 'add_points',
247
- 'pretrained_weight': init_pkl
248
- })
249
-
250
- # init image
251
- global_state = init_images(global_state)
252
- with gr.Row():
253
-
254
- with gr.Row():
255
-
256
- # Left --> tools
257
- with gr.Column(scale=3):
258
-
259
- # Pickle
260
- with gr.Row():
261
-
262
- with gr.Column(scale=1, min_width=10):
263
- gr.Markdown(value='Pickle', show_label=False)
264
-
265
- with gr.Column(scale=4, min_width=10):
266
- form_pretrained_dropdown = gr.Dropdown(
267
- choices=list(valid_checkpoints_dict.keys()),
268
- label="Pretrained Model",
269
- value=init_pkl,
270
- )
271
-
272
- # Latent
273
- with gr.Row():
274
- with gr.Column(scale=1, min_width=10):
275
- gr.Markdown(value='Latent', show_label=False)
276
-
277
- with gr.Column(scale=4, min_width=10):
278
- form_seed_number = gr.Slider(
279
- mininium=0,
280
- maximum=2**32-1,
281
- step=1,
282
- value=global_state.value['params']['seed'],
283
- interactive=True,
284
- # randomize=True,
285
- label="Seed",
286
- )
287
- form_lr_number = gr.Number(
288
- value=global_state.value["params"]["lr"],
289
- interactive=True,
290
- label="Step Size")
291
-
292
- with gr.Row():
293
- with gr.Column(scale=2, min_width=10):
294
- form_reset_image = gr.Button("Reset Image")
295
- with gr.Column(scale=3, min_width=10):
296
- form_latent_space = gr.Radio(
297
- ['w', 'w+'],
298
- value=global_state.value['params']
299
- ['latent_space'],
300
- interactive=True,
301
- label='Latent space to optimize',
302
- show_label=False,
303
- )
304
-
305
- # Drag
306
- with gr.Row():
307
- with gr.Column(scale=1, min_width=10):
308
- gr.Markdown(value='Drag', show_label=False)
309
- with gr.Column(scale=4, min_width=10):
310
- with gr.Row():
311
- with gr.Column(scale=1, min_width=10):
312
- enable_add_points = gr.Button('Add Points')
313
- with gr.Column(scale=1, min_width=10):
314
- undo_points = gr.Button('Reset Points')
315
- with gr.Row():
316
- with gr.Column(scale=1, min_width=10):
317
- form_start_btn = gr.Button("Start")
318
- with gr.Column(scale=1, min_width=10):
319
- form_stop_btn = gr.Button("Stop")
320
-
321
- form_steps_number = gr.Number(value=0,
322
- label="Steps",
323
- interactive=False)
324
-
325
- # Mask
326
- with gr.Row():
327
- with gr.Column(scale=1, min_width=10):
328
- gr.Markdown(value='Mask', show_label=False)
329
- with gr.Column(scale=4, min_width=10):
330
- enable_add_mask = gr.Button('Edit Flexible Area')
331
- with gr.Row():
332
- with gr.Column(scale=1, min_width=10):
333
- form_reset_mask_btn = gr.Button("Reset mask")
334
- with gr.Column(scale=1, min_width=10):
335
- show_mask = gr.Checkbox(
336
- label='Show Mask',
337
- value=global_state.value['show_mask'],
338
- show_label=False)
339
-
340
- with gr.Row():
341
- form_lambda_number = gr.Number(
342
- value=global_state.value["params"]
343
- ["motion_lambda"],
344
- interactive=True,
345
- label="Lambda",
346
- )
347
-
348
- form_draw_interval_number = gr.Number(
349
- value=global_state.value["draw_interval"],
350
- label="Draw Interval (steps)",
351
- interactive=True,
352
- visible=False)
353
-
354
- # Right --> Image
355
- with gr.Column(scale=8):
356
- form_image = ImageMask(
357
- value=global_state.value['images']['image_show'],
358
- brush_radius=20).style(
359
- width=768,
360
- height=768) # NOTE: hard image size code here.
361
- gr.Markdown("""
362
- ## Quick Start
363
-
364
- 1. Select desired `Pretrained Model` and adjust `Seed` to generate an
365
- initial image.
366
- 2. Click on image to add control points.
367
- 3. Click `Start` and enjoy it!
368
-
369
- ## Advance Usage
370
-
371
- 1. Change `Step Size` to adjust learning rate in drag optimization.
372
- 2. Select `w` or `w+` to change latent space to optimize:
373
- * Optimize on `w` space may cause greater influence to the image.
374
- * Optimize on `w+` space may work slower than `w`, but usually achieve
375
- better results.
376
- * Note that changing the latent space will reset the image, points and
377
- mask (this has the same effect as `Reset Image` button).
378
- 3. Click `Edit Flexible Area` to create a mask and constrain the
379
- unmasked region to remain unchanged.
380
-
381
-
382
- """)
383
- gr.HTML("""
384
- <style>
385
- .container {
386
- position: absolute;
387
- height: 50px;
388
- text-align: center;
389
- line-height: 50px;
390
- width: 100%;
391
- }
392
- </style>
393
- <div class="container">
394
- Gradio demo supported by
395
- <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
396
- <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
397
- </div>
398
- """)
399
- # Network & latents tab listeners
400
-
401
- def on_change_pretrained_dropdown(pretrained_value, global_state):
402
- """Function to handle model change.
403
- 1. Set pretrained value to global_state
404
- 2. Re-init images and clear all states
405
- """
406
-
407
- global_state['pretrained_weight'] = pretrained_value
408
- init_images(global_state)
409
- clear_state(global_state)
410
-
411
- return global_state, global_state["images"]['image_show']
412
-
413
- form_pretrained_dropdown.change(
414
- on_change_pretrained_dropdown,
415
- inputs=[form_pretrained_dropdown, global_state],
416
- outputs=[global_state, form_image],
417
- queue=True,
418
- )
419
-
420
- def on_click_reset_image(global_state):
421
- """Reset image to the original one and clear all states
422
- 1. Re-init images
423
- 2. Clear all states
424
- """
425
-
426
- init_images(global_state)
427
- clear_state(global_state)
428
-
429
- return global_state, global_state['images']['image_show']
430
-
431
- form_reset_image.click(
432
- on_click_reset_image,
433
- inputs=[global_state],
434
- outputs=[global_state, form_image],
435
- queue=False,
436
- )
437
-
438
- # Update parameters
439
- def on_change_update_image_seed(seed, global_state):
440
- """Function to handle generation seed change.
441
- 1. Set seed to global_state
442
- 2. Re-init images and clear all states
443
- """
444
-
445
- global_state["params"]["seed"] = int(seed)
446
- init_images(global_state)
447
- clear_state(global_state)
448
-
449
- return global_state, global_state['images']['image_show']
450
-
451
- form_seed_number.change(
452
- on_change_update_image_seed,
453
- inputs=[form_seed_number, global_state],
454
- outputs=[global_state, form_image],
455
- )
456
-
457
- def on_click_latent_space(latent_space, global_state):
458
- """Function to reset latent space to optimize.
459
- NOTE: this function we reset the image and all controls
460
- 1. Set latent-space to global_state
461
- 2. Re-init images and clear all state
462
- """
463
-
464
- global_state['params']['latent_space'] = latent_space
465
- init_images(global_state)
466
- clear_state(global_state)
467
-
468
- return global_state, global_state['images']['image_show']
469
-
470
- form_latent_space.change(on_click_latent_space,
471
- inputs=[form_latent_space, global_state],
472
- outputs=[global_state, form_image])
473
-
474
- # ==== Params
475
- form_lambda_number.change(
476
- partial(on_change_single_global_state, ["params", "motion_lambda"]),
477
- inputs=[form_lambda_number, global_state],
478
- outputs=[global_state],
479
- )
480
-
481
- def on_change_lr(lr, global_state):
482
- if lr == 0:
483
- print('lr is 0, do nothing.')
484
- return global_state
485
- else:
486
- global_state["params"]["lr"] = lr
487
- renderer = global_state['renderer']
488
- renderer.update_lr(lr)
489
- print('New optimizer: ')
490
- print(renderer.w_optim)
491
- return global_state
492
-
493
- form_lr_number.change(
494
- on_change_lr,
495
- inputs=[form_lr_number, global_state],
496
- outputs=[global_state],
497
- queue=False,
498
- )
499
-
500
- def on_click_start(global_state, image):
501
- p_in_pixels = []
502
- t_in_pixels = []
503
- valid_points = []
504
-
505
- # handle of start drag in mask editing mode
506
- global_state = preprocess_mask_info(global_state, image)
507
-
508
- # Prepare the points for the inference
509
- if len(global_state["points"]) == 0:
510
- # yield on_click_start_wo_points(global_state, image)
511
- image_raw = global_state['images']['image_raw']
512
- update_image_draw(
513
- image_raw,
514
- global_state['points'],
515
- global_state['mask'],
516
- global_state['show_mask'],
517
- global_state,
518
- )
519
-
520
- yield (
521
- global_state,
522
- 0,
523
- global_state['images']['image_show'],
524
- # gr.File.update(visible=False),
525
- gr.Button.update(interactive=True),
526
- gr.Button.update(interactive=True),
527
- gr.Button.update(interactive=True),
528
- gr.Button.update(interactive=True),
529
- gr.Button.update(interactive=True),
530
- # latent space
531
- gr.Radio.update(interactive=True),
532
- gr.Button.update(interactive=True),
533
- # NOTE: disable stop button
534
- gr.Button.update(interactive=False),
535
-
536
- # update other comps
537
- gr.Dropdown.update(interactive=True),
538
- gr.Number.update(interactive=True),
539
- gr.Number.update(interactive=True),
540
- gr.Button.update(interactive=True),
541
- gr.Button.update(interactive=True),
542
- gr.Checkbox.update(interactive=True),
543
- # gr.Number.update(interactive=True),
544
- gr.Number.update(interactive=True),
545
- )
546
- else:
547
-
548
- # Transform the points into torch tensors
549
- for key_point, point in global_state["points"].items():
550
- try:
551
- p_start = point.get("start_temp", point["start"])
552
- p_end = point["target"]
553
-
554
- if p_start is None or p_end is None:
555
- continue
556
-
557
- except KeyError:
558
- continue
559
-
560
- p_in_pixels.append(p_start)
561
- t_in_pixels.append(p_end)
562
- valid_points.append(key_point)
563
-
564
- mask = torch.tensor(global_state['mask']).float()
565
- drag_mask = 1 - mask
566
-
567
- renderer: Renderer = global_state["renderer"]
568
- global_state['temporal_params']['stop'] = False
569
- global_state['editing_state'] = 'running'
570
-
571
- # reverse points order
572
- p_to_opt = reverse_point_pairs(p_in_pixels)
573
- t_to_opt = reverse_point_pairs(t_in_pixels)
574
- print('Running with:')
575
- print(f' Source: {p_in_pixels}')
576
- print(f' Target: {t_in_pixels}')
577
- step_idx = 0
578
- last_time = time.time()
579
- while True:
580
- print_memory_usage()
581
- # add a TIMEOUT break
582
- print(f'Running time: {time.time() - last_time}')
583
- if IS_SPACE and time.time() - last_time > TIMEOUT:
584
- print('Timeout break!')
585
- break
586
- if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
587
- break
588
-
589
- # do drage here!
590
- renderer._render_drag_impl(
591
- global_state['generator_params'],
592
- p_to_opt, # point
593
- t_to_opt, # target
594
- drag_mask, # mask,
595
- global_state['params']['motion_lambda'], # lambda_mask
596
- reg=0,
597
- feature_idx=5, # NOTE: do not support change for now
598
- r1=global_state['params']['r1_in_pixels'], # r1
599
- r2=global_state['params']['r2_in_pixels'], # r2
600
- # random_seed = 0,
601
- # noise_mode = 'const',
602
- trunc_psi=global_state['params']['trunc_psi'],
603
- # force_fp32 = False,
604
- # layer_name = None,
605
- # sel_channels = 3,
606
- # base_channel = 0,
607
- # img_scale_db = 0,
608
- # img_normalize = False,
609
- # untransform = False,
610
- is_drag=True,
611
- to_pil=True)
612
-
613
- if step_idx % global_state['draw_interval'] == 0:
614
- print('Current Source:')
615
- for key_point, p_i, t_i in zip(valid_points, p_to_opt,
616
- t_to_opt):
617
- global_state["points"][key_point]["start_temp"] = [
618
- p_i[1],
619
- p_i[0],
620
- ]
621
- global_state["points"][key_point]["target"] = [
622
- t_i[1],
623
- t_i[0],
624
- ]
625
- start_temp = global_state["points"][key_point][
626
- "start_temp"]
627
- print(f' {start_temp}')
628
-
629
- image_result = global_state['generator_params']['image']
630
- image_draw = update_image_draw(
631
- image_result,
632
- global_state['points'],
633
- global_state['mask'],
634
- global_state['show_mask'],
635
- global_state,
636
- )
637
- global_state['images']['image_raw'] = image_result
638
-
639
- yield (
640
- global_state,
641
- step_idx,
642
- global_state['images']['image_show'],
643
- # gr.File.update(visible=False),
644
- gr.Button.update(interactive=False),
645
- gr.Button.update(interactive=False),
646
- gr.Button.update(interactive=False),
647
- gr.Button.update(interactive=False),
648
- gr.Button.update(interactive=False),
649
- # latent space
650
- gr.Radio.update(interactive=False),
651
- gr.Button.update(interactive=False),
652
- # enable stop button in loop
653
- gr.Button.update(interactive=True),
654
-
655
- # update other comps
656
- gr.Dropdown.update(interactive=False),
657
- gr.Number.update(interactive=False),
658
- gr.Number.update(interactive=False),
659
- gr.Button.update(interactive=False),
660
- gr.Button.update(interactive=False),
661
- gr.Checkbox.update(interactive=False),
662
- # gr.Number.update(interactive=False),
663
- gr.Number.update(interactive=False),
664
- )
665
-
666
- # increate step
667
- step_idx += 1
668
-
669
- image_result = global_state['generator_params']['image']
670
- global_state['images']['image_raw'] = image_result
671
- image_draw = update_image_draw(image_result,
672
- global_state['points'],
673
- global_state['mask'],
674
- global_state['show_mask'],
675
- global_state)
676
-
677
- # fp = NamedTemporaryFile(suffix=".png", delete=False)
678
- # image_result.save(fp, "PNG")
679
-
680
- global_state['editing_state'] = 'add_points'
681
-
682
- yield (
683
- global_state,
684
- 0, # reset step to 0 after stop.
685
- global_state['images']['image_show'],
686
- # gr.File.update(visible=True, value=fp.name),
687
- gr.Button.update(interactive=True),
688
- gr.Button.update(interactive=True),
689
- gr.Button.update(interactive=True),
690
- gr.Button.update(interactive=True),
691
- gr.Button.update(interactive=True),
692
- # latent space
693
- gr.Radio.update(interactive=True),
694
- gr.Button.update(interactive=True),
695
- # NOTE: disable stop button with loop finish
696
- gr.Button.update(interactive=False),
697
-
698
- # update other comps
699
- gr.Dropdown.update(interactive=True),
700
- gr.Number.update(interactive=True),
701
- gr.Number.update(interactive=True),
702
- gr.Checkbox.update(interactive=True),
703
- gr.Number.update(interactive=True),
704
- )
705
-
706
- form_start_btn.click(
707
- on_click_start,
708
- inputs=[global_state, form_image],
709
- outputs=[
710
- global_state,
711
- form_steps_number,
712
- form_image,
713
- # form_download_result_file,
714
- # >>> buttons
715
- form_reset_image,
716
- enable_add_points,
717
- enable_add_mask,
718
- undo_points,
719
- form_reset_mask_btn,
720
- form_latent_space,
721
- form_start_btn,
722
- form_stop_btn,
723
- # <<< buttonm
724
- # >>> inputs comps
725
- form_pretrained_dropdown,
726
- form_seed_number,
727
- form_lr_number,
728
- show_mask,
729
- form_lambda_number,
730
- ],
731
- )
732
-
733
- def on_click_stop(global_state):
734
- """Function to handle stop button is clicked.
735
- 1. send a stop signal by set global_state["temporal_params"]["stop"] as True
736
- 2. Disable Stop button
737
- """
738
- global_state["temporal_params"]["stop"] = True
739
-
740
- return global_state, gr.Button.update(interactive=False)
741
-
742
- form_stop_btn.click(on_click_stop,
743
- inputs=[global_state],
744
- outputs=[global_state, form_stop_btn],
745
- queue=False)
746
-
747
- form_draw_interval_number.change(
748
- partial(
749
- on_change_single_global_state,
750
- "draw_interval",
751
- map_transform=lambda x: int(x),
752
- ),
753
- inputs=[form_draw_interval_number, global_state],
754
- outputs=[global_state],
755
- queue=False,
756
- )
757
-
758
- def on_click_remove_point(global_state):
759
- choice = global_state["curr_point"]
760
- del global_state["points"][choice]
761
-
762
- choices = list(global_state["points"].keys())
763
-
764
- if len(choices) > 0:
765
- global_state["curr_point"] = choices[0]
766
-
767
- return (
768
- gr.Dropdown.update(choices=choices, value=choices[0]),
769
- global_state,
770
- )
771
-
772
- # Mask
773
- def on_click_reset_mask(global_state):
774
- global_state['mask'] = np.ones(
775
- (
776
- global_state["images"]["image_raw"].size[1],
777
- global_state["images"]["image_raw"].size[0],
778
- ),
779
- dtype=np.uint8,
780
- )
781
- image_draw = update_image_draw(global_state['images']['image_raw'],
782
- global_state['points'],
783
- global_state['mask'],
784
- global_state['show_mask'], global_state)
785
- return global_state, image_draw
786
-
787
- form_reset_mask_btn.click(
788
- on_click_reset_mask,
789
- inputs=[global_state],
790
- outputs=[global_state, form_image],
791
- )
792
-
793
- # Image
794
- def on_click_enable_draw(global_state, image):
795
- """Function to start add mask mode.
796
- 1. Preprocess mask info from last state
797
- 2. Change editing state to add_mask
798
- 3. Set curr image with points and mask
799
- """
800
- global_state = preprocess_mask_info(global_state, image)
801
- global_state['editing_state'] = 'add_mask'
802
- image_raw = global_state['images']['image_raw']
803
- image_draw = update_image_draw(image_raw, global_state['points'],
804
- global_state['mask'], True,
805
- global_state)
806
- return (global_state,
807
- gr.Image.update(value=image_draw, interactive=True))
808
-
809
- def on_click_remove_draw(global_state, image):
810
- """Function to start remove mask mode.
811
- 1. Preprocess mask info from last state
812
- 2. Change editing state to remove_mask
813
- 3. Set curr image with points and mask
814
- """
815
- global_state = preprocess_mask_info(global_state, image)
816
- global_state['edinting_state'] = 'remove_mask'
817
- image_raw = global_state['images']['image_raw']
818
- image_draw = update_image_draw(image_raw, global_state['points'],
819
- global_state['mask'], True,
820
- global_state)
821
- return (global_state,
822
- gr.Image.update(value=image_draw, interactive=True))
823
-
824
- enable_add_mask.click(on_click_enable_draw,
825
- inputs=[global_state, form_image],
826
- outputs=[
827
- global_state,
828
- form_image,
829
- ],
830
- queue=False)
831
-
832
- def on_click_add_point(global_state, image: dict):
833
- """Function switch from add mask mode to add points mode.
834
- 1. Updaste mask buffer if need
835
- 2. Change global_state['editing_state'] to 'add_points'
836
- 3. Set current image with mask
837
- """
838
-
839
- global_state = preprocess_mask_info(global_state, image)
840
- global_state['editing_state'] = 'add_points'
841
- mask = global_state['mask']
842
- image_raw = global_state['images']['image_raw']
843
- image_draw = update_image_draw(image_raw, global_state['points'], mask,
844
- global_state['show_mask'], global_state)
845
-
846
- return (global_state,
847
- gr.Image.update(value=image_draw, interactive=False))
848
-
849
- enable_add_points.click(on_click_add_point,
850
- inputs=[global_state, form_image],
851
- outputs=[global_state, form_image],
852
- queue=False)
853
-
854
- def on_click_image(global_state, evt: gr.SelectData):
855
- """This function only support click for point selection
856
- """
857
- xy = evt.index
858
- if global_state['editing_state'] != 'add_points':
859
- print(f'In {global_state["editing_state"]} state. '
860
- 'Do not add points.')
861
-
862
- return global_state, global_state['images']['image_show']
863
-
864
- points = global_state["points"]
865
-
866
- point_idx = get_latest_points_pair(points)
867
- if point_idx is None:
868
- points[0] = {'start': xy, 'target': None}
869
- print(f'Click Image - Start - {xy}')
870
- elif points[point_idx].get('target', None) is None:
871
- points[point_idx]['target'] = xy
872
- print(f'Click Image - Target - {xy}')
873
- else:
874
- points[point_idx + 1] = {'start': xy, 'target': None}
875
- print(f'Click Image - Start - {xy}')
876
-
877
- image_raw = global_state['images']['image_raw']
878
- image_draw = update_image_draw(
879
- image_raw,
880
- global_state['points'],
881
- global_state['mask'],
882
- global_state['show_mask'],
883
- global_state,
884
- )
885
-
886
- return global_state, image_draw
887
-
888
- form_image.select(
889
- on_click_image,
890
- inputs=[global_state],
891
- outputs=[global_state, form_image],
892
- queue=False,
893
- )
894
-
895
- def on_click_clear_points(global_state):
896
- """Function to handle clear all control points
897
- 1. clear global_state['points'] (clear_state)
898
- 2. re-init network
899
- 2. re-draw image
900
- """
901
- clear_state(global_state, target='point')
902
-
903
- renderer: Renderer = global_state["renderer"]
904
- renderer.feat_refs = None
905
-
906
- image_raw = global_state['images']['image_raw']
907
- image_draw = update_image_draw(image_raw, {}, global_state['mask'],
908
- global_state['show_mask'], global_state)
909
- return global_state, image_draw
910
-
911
- undo_points.click(on_click_clear_points,
912
- inputs=[global_state],
913
- outputs=[global_state, form_image],
914
- queue=False)
915
-
916
- def on_click_show_mask(global_state, show_mask):
917
- """Function to control whether show mask on image."""
918
- global_state['show_mask'] = show_mask
919
-
920
- image_raw = global_state['images']['image_raw']
921
- image_draw = update_image_draw(
922
- image_raw,
923
- global_state['points'],
924
- global_state['mask'],
925
- global_state['show_mask'],
926
- global_state,
927
- )
928
- return global_state, image_draw
929
-
930
- show_mask.change(
931
- on_click_show_mask,
932
- inputs=[global_state, show_mask],
933
- outputs=[global_state, form_image],
934
- queue=False,
935
- )
936
-
937
- print("SHAReD: Start app", parser.parse_args())
938
- gr.close_all()
939
- app.queue(concurrency_count=1, max_size=200, api_open=False)
940
- app.launch(share=args.share, show_api=False)