forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimage_processor.py
More file actions
181 lines (156 loc) · 6.83 KB
/
image_processor.py
File metadata and controls
181 lines (156 loc) · 6.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Union
import numpy as np
import PIL
import torch
from PIL import Image
from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION
class VaeImageProcessor(ConfigMixin):
    """
    Image Processor for VAE

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1]
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        # The arguments are not stored here; the other methods read them back through `self.config`
        # (presumably populated by `register_to_config` - confirm against `configuration_utils`).
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        Args:
            images (`np.ndarray`):
                A single image of shape `(height, width, channels)` or a batch of shape
                `(batch, height, width, channels)`. Values are multiplied by 255 and cast to `uint8`, so they are
                expected to lie in [0, 1].

        Returns:
            `List[PIL.Image.Image]`: One PIL image per batch entry.
        """
        if images.ndim == 3:
            # a single image: add the leading batch axis
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def numpy_to_pt(images):
        """
        Convert a batch of numpy images to a pytorch tensor.

        A 3-dimensional input gets a trailing channel axis appended; the array is then transposed from
        channels-last `(batch, height, width, channels)` to channels-first `(batch, channels, height, width)`.
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images):
        """
        Convert a pytorch tensor to a numpy image.

        The tensor is moved to the CPU, permuted from `(batch, channels, height, width)` to
        `(batch, height, width, channels)` and cast to float32 before the conversion.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1] (assuming the input values lie in [0,1]).
        """
        return 2.0 * images - 1.0

    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
        """
        w, h = images.size
        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
        return images

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors, either as a
        single item or as a list of items.

        Args:
            image:
                The image(s) to preprocess. Numpy arrays and pytorch tensors with 4 dimensions are concatenated along
                the first axis; otherwise the items are stacked into a new batch axis. PIL images are resized when
                `do_resize` is set; numpy arrays and pytorch tensors are never resized.

        Returns:
            `torch.Tensor`: The batched image tensor, normalized to [-1,1] when `do_normalize` is set and the input
            does not already contain negative values.

        Raises:
            ValueError: If the input is not one of the supported formats, or if `do_resize` is set and a numpy array
                or pytorch tensor input has a height or width that is not divisible by `vae_scale_factor`.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            # Build the message from strings only: `str.join` rejects class objects, and a non-list input
            # may not even be iterable, so describe it by its own type instead of iterating over it.
            received = [type(i) for i in image] if isinstance(image, list) else type(image)
            supported_names = ", ".join(f"{fmt.__module__}.{fmt.__qualname__}" for fmt in supported_formats)
            raise ValueError(
                f"Input is in incorrect format: {received}. Currently, we only support {supported_names}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if self.config.do_resize:
                image = [self.resize(i) for i in image]
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
            image = np.stack(image, axis=0)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.config.do_resize and (
                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
            ):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
            _, _, height, width = image.shape
            if self.config.do_resize and (
                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
            ):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if image.min() < 0:
            # negative values mean the input is already outside [0,1]; warn and skip normalization
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image,
        output_type: str = "pil",
    ):
        """
        Convert an image tensor to the requested output format.

        Args:
            image:
                The image batch to convert. A pytorch tensor is returned unchanged when `output_type` is `"pt"`;
                otherwise the input is converted with `pt_to_numpy`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                One of `"pt"`, `"np"` or `"pil"`.

        Returns:
            The input tensor (`"pt"`), a numpy array (`"np"`) or a list of PIL images (`"pil"`).

        Raises:
            ValueError: If `output_type` is not one of the supported values.
        """
        if isinstance(image, torch.Tensor) and output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image
        elif output_type == "pil":
            return self.numpy_to_pil(image)
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")