t8star committed (verified)
Commit 011036f · 1 Parent(s): 533673d

Upload imagecreatemask.py

Files changed (1)
  1. imagecreatemask.py +109 -0
imagecreatemask.py ADDED
@@ -0,0 +1,109 @@
+ from PIL import Image
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as TF
+
+ def tensor2pil(image):
+     return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
+
+ def pil2tensor(image):
+     return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
+
+ def tensor2mask(t: torch.Tensor) -> torch.Tensor:
+     size = t.size()
+     if len(size) < 4:
+         return t
+     if size[3] == 1:
+         return t[:, :, :, 0]
+     elif size[3] == 4:
+         # Use alpha if available
+         if torch.min(t[:, :, :, 3]).item() != 1.:
+             return t[:, :, :, 3]
+     # Convert RGB to grayscale
+     return TF.rgb_to_grayscale(t.permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
+
+ class image_concat_mask:
+     def __init__(self):
+         pass
+
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "image1": ("IMAGE",),
+             },
+             "optional": {
+                 "image2": ("IMAGE",),
+                 "mask": ("MASK",),
+             }
+         }
+
+     RETURN_TYPES = ("IMAGE", "MASK",)
+     RETURN_NAMES = ("image", "mask")
+     FUNCTION = "image_concat_mask"
+     CATEGORY = "hhy"
+
+     def image_concat_mask(self, image1, image2=None, mask=None):
+         processed_images = []
+         masks = []
+
+         for idx, img1 in enumerate(image1):
+             # Convert tensor to PIL
+             pil_image1 = tensor2pil(img1)
+
+             # Get first image dimensions
+             width1, height1 = pil_image1.size
+
+             if image2 is not None and idx < len(image2):
+                 # Use provided second image
+                 pil_image2 = tensor2pil(image2[idx])
+                 width2, height2 = pil_image2.size
+
+                 # Resize image2 to match height of image1
+                 new_width2 = int(width2 * (height1 / height2))
+                 pil_image2 = pil_image2.resize((new_width2, height1), Image.Resampling.LANCZOS)
+             else:
+                 # Create white image with same dimensions as image1
+                 pil_image2 = Image.new('RGB', (width1, height1), 'white')
+                 new_width2 = width1
+
+             # Create new image to hold both images side by side
+             combined_image = Image.new('RGB', (width1 + new_width2, height1))
+
+             # Paste both images
+             combined_image.paste(pil_image1, (0, 0))
+             combined_image.paste(pil_image2, (width1, 0))
+
+             # Convert combined image to tensor
+             combined_tensor = pil2tensor(combined_image)
+             processed_images.append(combined_tensor)
+
+             # Create mask (0 for left image area, 1 for right image area)
+             final_mask = torch.zeros((1, height1, width1 + new_width2))
+             final_mask[:, :, width1:] = 1.0  # Set right half to 1
+
+             # If mask is provided, subtract it from the right side
+             if mask is not None and idx < len(mask):
+                 input_mask = mask[idx]
+                 # Resize input mask to match the right-hand panel (new_width2 x height1)
+                 pil_input_mask = tensor2pil(input_mask)
+                 pil_input_mask = pil_input_mask.resize((new_width2, height1), Image.Resampling.LANCZOS)
+                 resized_input_mask = pil2tensor(pil_input_mask)
+
+                 # Subtract input mask from the right side
+                 final_mask[:, :, width1:] *= (1.0 - resized_input_mask)
+
+             masks.append(final_mask)
+
+         processed_images = torch.cat(processed_images, dim=0)
+         masks = torch.cat(masks, dim=0)
+
+         return (processed_images, masks)
+
+ NODE_CLASS_MAPPINGS = {
+     "image concat mask": image_concat_mask
+ }
+
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "image concat mask": "Image Concat with Mask"
+ }
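
For a quick sanity check outside ComfyUI, a minimal sketch like the one below can call the node's method directly. It is not part of the committed file; it assumes ComfyUI-style inputs, i.e. image tensors shaped (batch, height, width, channel) with values in [0, 1] and mask tensors shaped (batch, height, width), and the example sizes are made up for illustration.

# Standalone sketch (assumption: not part of imagecreatemask.py itself).
# Feeds random ComfyUI-style tensors through the node and checks output shapes.
import torch
from imagecreatemask import image_concat_mask

node = image_concat_mask()

image1 = torch.rand(1, 256, 256, 3)   # left image, (B, H, W, C) in [0, 1]
image2 = torch.rand(1, 256, 192, 3)   # right image; the node resizes it to height 256
mask = torch.zeros(1, 256, 192)       # optional mask applied to the right-hand panel

images_out, masks_out = node.image_concat_mask(image1, image2=image2, mask=mask)
print(images_out.shape)  # expected torch.Size([1, 256, 448, 3]): 256 + 192 wide after concat
print(masks_out.shape)   # expected torch.Size([1, 256, 448]): 0 on the left, 1 - mask on the right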