Upload imagecreatemask.py
Browse files- imagecreatemask.py +109 -0
imagecreatemask.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torchvision.transforms.functional as TF
|
5 |
+
|
6 |
+
def tensor2pil(image):
|
7 |
+
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
8 |
+
|
9 |
+
def pil2tensor(image):
|
10 |
+
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
11 |
+
|
12 |
+
def tensor2mask(t: torch.Tensor) -> torch.Tensor:
|
13 |
+
size = t.size()
|
14 |
+
if (len(size) < 4):
|
15 |
+
return t
|
16 |
+
if size[3] == 1:
|
17 |
+
return t[:,:,:,0]
|
18 |
+
elif size[3] == 4:
|
19 |
+
# Use alpha if available
|
20 |
+
if torch.min(t[:, :, :, 3]).item() != 1.:
|
21 |
+
return t[:,:,:,3]
|
22 |
+
# Convert RGB to grayscale
|
23 |
+
return TF.rgb_to_grayscale(t.permute(0,3,1,2), num_output_channels=1)[:,0,:,:]
|
24 |
+
|
25 |
+
class image_concat_mask:
|
26 |
+
def __init__(self):
|
27 |
+
pass
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def INPUT_TYPES(cls):
|
31 |
+
return {
|
32 |
+
"required": {
|
33 |
+
"image1": ("IMAGE",),
|
34 |
+
},
|
35 |
+
"optional": {
|
36 |
+
"image2": ("IMAGE",),
|
37 |
+
"mask": ("MASK",),
|
38 |
+
}
|
39 |
+
}
|
40 |
+
|
41 |
+
RETURN_TYPES = ("IMAGE", "MASK",)
|
42 |
+
RETURN_NAMES = ("image", "mask")
|
43 |
+
FUNCTION = "image_concat_mask"
|
44 |
+
CATEGORY = "hhy"
|
45 |
+
|
46 |
+
def image_concat_mask(self, image1, image2=None, mask=None):
|
47 |
+
processed_images = []
|
48 |
+
masks = []
|
49 |
+
|
50 |
+
for idx, img1 in enumerate(image1):
|
51 |
+
# Convert tensor to PIL
|
52 |
+
pil_image1 = tensor2pil(img1)
|
53 |
+
|
54 |
+
# Get first image dimensions
|
55 |
+
width1, height1 = pil_image1.size
|
56 |
+
|
57 |
+
if image2 is not None and idx < len(image2):
|
58 |
+
# Use provided second image
|
59 |
+
pil_image2 = tensor2pil(image2[idx])
|
60 |
+
width2, height2 = pil_image2.size
|
61 |
+
|
62 |
+
# Resize image2 to match height of image1
|
63 |
+
new_width2 = int(width2 * (height1 / height2))
|
64 |
+
pil_image2 = pil_image2.resize((new_width2, height1), Image.Resampling.LANCZOS)
|
65 |
+
else:
|
66 |
+
# Create white image with same dimensions as image1
|
67 |
+
pil_image2 = Image.new('RGB', (width1, height1), 'white')
|
68 |
+
new_width2 = width1
|
69 |
+
|
70 |
+
# Create new image to hold both images side by side
|
71 |
+
combined_image = Image.new('RGB', (width1 + new_width2, height1))
|
72 |
+
|
73 |
+
# Paste both images
|
74 |
+
combined_image.paste(pil_image1, (0, 0))
|
75 |
+
combined_image.paste(pil_image2, (width1, 0))
|
76 |
+
|
77 |
+
# Convert combined image to tensor
|
78 |
+
combined_tensor = pil2tensor(combined_image)
|
79 |
+
processed_images.append(combined_tensor)
|
80 |
+
|
81 |
+
# Create mask (0 for left image area, 1 for right image area)
|
82 |
+
final_mask = torch.zeros((1, height1, width1 + new_width2))
|
83 |
+
final_mask[:, :, width1:] = 1.0 # Set right half to 1
|
84 |
+
|
85 |
+
# If mask is provided, subtract it from the right side
|
86 |
+
if mask is not None and idx < len(mask):
|
87 |
+
input_mask = mask[idx]
|
88 |
+
# Resize input mask to match height1
|
89 |
+
pil_input_mask = tensor2pil(input_mask)
|
90 |
+
pil_input_mask = pil_input_mask.resize((new_width2, height1), Image.Resampling.LANCZOS)
|
91 |
+
resized_input_mask = pil2tensor(pil_input_mask)
|
92 |
+
|
93 |
+
# Subtract input mask from the right side
|
94 |
+
final_mask[:, :, width1:] *= (1.0 - resized_input_mask)
|
95 |
+
|
96 |
+
masks.append(final_mask)
|
97 |
+
|
98 |
+
processed_images = torch.cat(processed_images, dim=0)
|
99 |
+
masks = torch.cat(masks, dim=0)
|
100 |
+
|
101 |
+
return (processed_images, masks)
|
102 |
+
|
103 |
+
NODE_CLASS_MAPPINGS = {
|
104 |
+
"image concat mask": image_concat_mask
|
105 |
+
}
|
106 |
+
|
107 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
108 |
+
"image concat mask": "Image Concat with Mask"
|
109 |
+
}
|