diff --git a/.gitignore b/.gitignore
index 0ada365..a8f7742 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 .venv
-engines
\ No newline at end of file
+engines
+/__pycache__
+/utils/__pycache__
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f6a6498
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,56 @@
+accelerate==1.8.1
+antlr4-python3-runtime==4.9.3
+certifi==2022.12.7
+charset-normalizer==2.1.1
+colorama==0.4.6
+colored==2.3.0
+coloredlogs==15.0.1
+cuda-bindings==12.9.0
+cuda-python==12.9.0
+diffusers==0.24.0
+filelock==3.13.1
+fire==0.7.0
+flatbuffers==25.2.10
+fsspec==2024.6.1
+huggingface-hub==0.25.2
+humanfriendly==10.0
+idna==3.4
+importlib_metadata==8.7.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.3
+numpy==1.26.4
+oauthlib==3.3.1
+omegaconf==2.3.0
+onnx==1.15.0
+onnx_graphsurgeon==0.5.8
+onnxruntime==1.16.3
+packaging==25.0
+pillow==11.0.0
+polygraphy==0.49.24
+protobuf==3.20.2
+psutil==7.0.0
+PyOpenGL==3.1.9
+pyreadline3==3.5.4
+python-osc==1.9.3
+pywin32==311
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.28.1
+requests-oauthlib==2.0.0
+safetensors==0.5.3
+SpoutGL==0.1.1
+streamdiffusion @ git+https://github.com/cumulo-autumn/StreamDiffusion.git@b623251dc055e1fd858d53509aa43e09dfc5cdc0
+sympy==1.13.3
+termcolor==3.1.0
+tokenizers==0.15.2
+torch==2.1.0+cu121
+torchvision==0.16.0+cu121
+tqdm==4.67.1
+transformers==4.35.2
+twython==3.9.1
+typing_extensions==4.12.2
+urllib3==1.26.13
+xformers==0.0.22.post7
+zipp==3.23.0
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/viewer.py b/utils/viewer.py
new file mode 100644
index 0000000..dd6f6ca
--- /dev/null
+++ b/utils/viewer.py
@@ -0,0 +1,98 @@
+import os
+import sys
+import threading
+import time
+import tkinter as tk
+from multiprocessing import Queue
+from typing import List
+from PIL import Image, ImageTk
+from streamdiffusion.image_utils import postprocess_image
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+
+def update_image(image_data: Image.Image, label: tk.Label) -> None:
+    """
+    Update the image displayed on a Tkinter label.
+
+    Parameters
+    ----------
+    image_data : Image.Image
+        The image to be displayed.
+    label : tk.Label
+        The label where the image will be updated.
+    """
+    width = 512
+    height = 512
+    tk_image = ImageTk.PhotoImage(image_data, size=width)
+    label.configure(image=tk_image, width=width, height=height)
+    label.image = tk_image  # keep a reference
+
+def _receive_images(
+    queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label
+) -> None:
+    """
+    Continuously receive images from a queue and update the labels.
+
+    Parameters
+    ----------
+    queue : Queue
+        The queue to receive images from.
+    fps_queue : Queue
+        The queue to put the calculated fps.
+    label : tk.Label
+        The label to update with images.
+    fps_label : tk.Label
+        The label to show fps.
+    """
+    while True:
+        try:
+            if not queue.empty():
+                label.after(
+                    0,
+                    update_image,
+                    postprocess_image(queue.get(block=False), output_type="pil")[0],
+                    label,
+                )
+            if not fps_queue.empty():
+                fps_label.config(text=f"FPS: {fps_queue.get(block=False):.2f}")
+
+            time.sleep(0.0005)
+        except KeyboardInterrupt:
+            return
+
+
+def receive_images(queue: Queue, fps_queue: Queue) -> None:
+    """
+    Set up the Tkinter window and start the thread to receive images.
+
+    Parameters
+    ----------
+    queue : Queue
+        The queue to receive images from.
+    fps_queue : Queue
+        The queue to put the calculated fps.
+    """
+    root = tk.Tk()
+    root.title("Image Viewer")
+    label = tk.Label(root)
+    fps_label = tk.Label(root, text="FPS: 0")
+    label.grid(column=0)
+    fps_label.grid(column=1)
+
+    def on_closing():
+        print("window closed")
+        root.quit()  # stop event loop
+        return
+
+    thread = threading.Thread(
+        target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True
+    )
+    thread.start()
+
+    try:
+        root.protocol("WM_DELETE_WINDOW", on_closing)
+        root.mainloop()
+    except KeyboardInterrupt:
+        return
+
diff --git a/utils/wrapper.py b/utils/wrapper.py
new file mode 100644
index 0000000..5c6ac12
--- /dev/null
+++ b/utils/wrapper.py
@@ -0,0 +1,663 @@
+import gc
+import os
+from pathlib import Path
+import traceback
+from typing import List, Literal, Optional, Union, Dict
+
+import numpy as np
+import torch
+from diffusers import AutoencoderTiny, StableDiffusionPipeline
+from PIL import Image
+
+from streamdiffusion import StreamDiffusion
+from streamdiffusion.image_utils import postprocess_image
+
+
+torch.set_grad_enabled(False)
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+
+class StreamDiffusionWrapper:
+    def __init__(
+        self,
+        model_id_or_path: str,
+        t_index_list: List[int],
+        lora_dict: Optional[Dict[str, float]] = None,
+        mode: Literal["img2img", "txt2img"] = "img2img",
+        output_type: Literal["pil", "pt", "np", "latent"] = "pil",
+        lcm_lora_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        device: Literal["cpu", "cuda"] = "cuda",
+        dtype: torch.dtype = torch.float16,
+        frame_buffer_size: int = 1,
+        width: int = 512,
+        height: int = 512,
+        warmup: int = 10,
+        acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt",
+        do_add_noise: bool = True,
+        device_ids: Optional[List[int]] = None,
+        use_lcm_lora: bool = True,
+        use_tiny_vae: bool = True,
+        enable_similar_image_filter: bool = False,
+        similar_image_filter_threshold: float = 0.98,
+        similar_image_filter_max_skip_frame: int = 10,
+        use_denoising_batch: bool = True,
+        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
+        seed: int = 2,
+        use_safety_checker: bool = False,
+        engine_dir: Optional[Union[str, Path]] = "engines",
+    ):
+        """
+        Initializes the StreamDiffusionWrapper.
+
+        Parameters
+        ----------
+        model_id_or_path : str
+            The model id or path to load.
+        t_index_list : List[int]
+            The t_index_list to use for inference.
+        lora_dict : Optional[Dict[str, float]], optional
+            The lora_dict to load, by default None.
+            Keys are the LoRA names and values are the LoRA scales.
+            Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...}
+        mode : Literal["img2img", "txt2img"], optional
+            txt2img or img2img, by default "img2img".
+        output_type : Literal["pil", "pt", "np", "latent"], optional
+            The output type of image, by default "pil".
+        lcm_lora_id : Optional[str], optional
+            The lcm_lora_id to load, by default None.
+            If None, the default LCM-LoRA
+            ("latent-consistency/lcm-lora-sdv1-5") will be used.
+        vae_id : Optional[str], optional
+            The vae_id to load, by default None.
+            If None, the default TinyVAE
+            ("madebyollin/taesd") will be used.
+        device : Literal["cpu", "cuda"], optional
+            The device to use for inference, by default "cuda".
+        dtype : torch.dtype, optional
+            The dtype for inference, by default torch.float16.
+        frame_buffer_size : int, optional
+            The frame buffer size for denoising batch, by default 1.
+        width : int, optional
+            The width of the image, by default 512.
+ height : int, optional + The height of the image, by default 512. + warmup : int, optional + The number of warmup steps to perform, by default 10. + acceleration : Literal["none", "xformers", "tensorrt"], optional + The acceleration method, by default "tensorrt". + do_add_noise : bool, optional + Whether to add noise for following denoising steps or not, + by default True. + device_ids : Optional[List[int]], optional + The device ids to use for DataParallel, by default None. + use_lcm_lora : bool, optional + Whether to use LCM-LoRA or not, by default True. + use_tiny_vae : bool, optional + Whether to use TinyVAE or not, by default True. + enable_similar_image_filter : bool, optional + Whether to enable similar image filter or not, + by default False. + similar_image_filter_threshold : float, optional + The threshold for similar image filter, by default 0.98. + similar_image_filter_max_skip_frame : int, optional + The max skip frame for similar image filter, by default 10. + use_denoising_batch : bool, optional + Whether to use denoising batch or not, by default True. + cfg_type : Literal["none", "full", "self", "initialize"], + optional + The cfg_type for img2img mode, by default "self". + You cannot use anything other than "none" for txt2img mode. + seed : int, optional + The seed, by default 2. + use_safety_checker : bool, optional + Whether to use safety checker or not, by default False. + """ + self.sd_turbo = "turbo" in model_id_or_path + + if mode == "txt2img": + if cfg_type != "none": + raise ValueError( + f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}" + ) + if use_denoising_batch and frame_buffer_size > 1: + if not self.sd_turbo: + raise ValueError( + "txt2img mode cannot use denoising batch with frame_buffer_size > 1." + ) + + if mode == "img2img": + if not use_denoising_batch: + raise NotImplementedError( + "img2img mode must use denoising batch for now." + ) + + self.device = device + self.dtype = dtype + self.width = width + self.height = height + self.mode = mode + self.output_type = output_type + self.frame_buffer_size = frame_buffer_size + self.batch_size = ( + len(t_index_list) * frame_buffer_size + if use_denoising_batch + else frame_buffer_size + ) + + self.use_denoising_batch = use_denoising_batch + self.use_safety_checker = use_safety_checker + + self.stream: StreamDiffusion = self._load_model( + model_id_or_path=model_id_or_path, + lora_dict=lora_dict, + lcm_lora_id=lcm_lora_id, + vae_id=vae_id, + t_index_list=t_index_list, + acceleration=acceleration, + warmup=warmup, + do_add_noise=do_add_noise, + use_lcm_lora=use_lcm_lora, + use_tiny_vae=use_tiny_vae, + cfg_type=cfg_type, + seed=seed, + engine_dir=engine_dir, + ) + + if device_ids is not None: + self.stream.unet = torch.nn.DataParallel( + self.stream.unet, device_ids=device_ids + ) + + if enable_similar_image_filter: + self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame) + + def prepare( + self, + prompt: str, + negative_prompt: str = "", + num_inference_steps: int = 50, + guidance_scale: float = 1.2, + delta: float = 1.0, + ) -> None: + """ + Prepares the model for inference. + + Parameters + ---------- + prompt : str + The prompt to generate images from. + num_inference_steps : int, optional + The number of inference steps to perform, by default 50. + guidance_scale : float, optional + The guidance scale to use, by default 1.2. + delta : float, optional + The delta multiplier of virtual residual noise, + by default 1.0. 
+ """ + self.stream.prepare( + prompt, + negative_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + delta=delta, + ) + + def __call__( + self, + image: Optional[Union[str, Image.Image, torch.Tensor]] = None, + prompt: Optional[str] = None, + ) -> Union[Image.Image, List[Image.Image]]: + """ + Performs img2img or txt2img based on the mode. + + Parameters + ---------- + image : Optional[Union[str, Image.Image, torch.Tensor]] + The image to generate from. + prompt : Optional[str] + The prompt to generate images from. + + Returns + ------- + Union[Image.Image, List[Image.Image]] + The generated image. + """ + if self.mode == "img2img": + return self.img2img(image, prompt) + else: + return self.txt2img(prompt) + + def txt2img( + self, prompt: Optional[str] = None + ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: + """ + Performs txt2img. + + Parameters + ---------- + prompt : Optional[str] + The prompt to generate images from. + + Returns + ------- + Union[Image.Image, List[Image.Image]] + The generated image. + """ + if prompt is not None: + self.stream.update_prompt(prompt) + + if self.sd_turbo: + image_tensor = self.stream.txt2img_sd_turbo(self.batch_size) + else: + image_tensor = self.stream.txt2img(self.frame_buffer_size) + image = self.postprocess_image(image_tensor, output_type=self.output_type) + + if self.use_safety_checker: + safety_checker_input = self.feature_extractor( + image, return_tensors="pt" + ).to(self.device) + _, has_nsfw_concept = self.safety_checker( + images=image_tensor.to(self.dtype), + clip_input=safety_checker_input.pixel_values.to(self.dtype), + ) + image = self.nsfw_fallback_img if has_nsfw_concept[0] else image + + return image + + def img2img( + self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None + ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: + """ + Performs img2img. + + Parameters + ---------- + image : Union[str, Image.Image, torch.Tensor] + The image to generate from. + + Returns + ------- + Image.Image + The generated image. + """ + if prompt is not None: + self.stream.update_prompt(prompt) + + if isinstance(image, str) or isinstance(image, Image.Image): + image = self.preprocess_image(image) + + image_tensor = self.stream(image) + image = self.postprocess_image(image_tensor, output_type=self.output_type) + + if self.use_safety_checker: + safety_checker_input = self.feature_extractor( + image, return_tensors="pt" + ).to(self.device) + _, has_nsfw_concept = self.safety_checker( + images=image_tensor.to(self.dtype), + clip_input=safety_checker_input.pixel_values.to(self.dtype), + ) + image = self.nsfw_fallback_img if has_nsfw_concept[0] else image + + return image + + def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor: + """ + Preprocesses the image. + + Parameters + ---------- + image : Union[str, Image.Image, torch.Tensor] + The image to preprocess. + + Returns + ------- + torch.Tensor + The preprocessed image. 
+ """ + if isinstance(image, str): + image = Image.open(image).convert("RGB").resize((self.width, self.height)) + if isinstance(image, Image.Image): + image = image.convert("RGB").resize((self.width, self.height)) + + return self.stream.image_processor.preprocess( + image, self.height, self.width + ).to(device=self.device, dtype=self.dtype) + + def postprocess_image( + self, image_tensor: torch.Tensor, output_type: str = "pil" + ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: + """ + Postprocesses the image. + + Parameters + ---------- + image_tensor : torch.Tensor + The image tensor to postprocess. + + Returns + ------- + Union[Image.Image, List[Image.Image]] + The postprocessed image. + """ + if self.frame_buffer_size > 1: + return postprocess_image(image_tensor.cpu(), output_type=output_type) + else: + return postprocess_image(image_tensor.cpu(), output_type=output_type)[0] + + def _load_model( + self, + model_id_or_path: str, + t_index_list: List[int], + lora_dict: Optional[Dict[str, float]] = None, + lcm_lora_id: Optional[str] = None, + vae_id: Optional[str] = None, + acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", + warmup: int = 10, + do_add_noise: bool = True, + use_lcm_lora: bool = True, + use_tiny_vae: bool = True, + cfg_type: Literal["none", "full", "self", "initialize"] = "self", + seed: int = 2, + engine_dir: Optional[Union[str, Path]] = "engines", + ) -> StreamDiffusion: + """ + Loads the model. + + This method does the following: + + 1. Loads the model from the model_id_or_path. + 2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed. + 3. Loads the VAE model from the vae_id if needed. + 4. Enables acceleration if needed. + 5. Prepares the model for inference. + 6. Load the safety checker if needed. + + Parameters + ---------- + model_id_or_path : str + The model id or path to load. + t_index_list : List[int] + The t_index_list to use for inference. + lora_dict : Optional[Dict[str, float]], optional + The lora_dict to load, by default None. + Keys are the LoRA names and values are the LoRA scales. + Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} + lcm_lora_id : Optional[str], optional + The lcm_lora_id to load, by default None. + vae_id : Optional[str], optional + The vae_id to load, by default None. + acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional + The acceleration method, by default "tensorrt". + warmup : int, optional + The number of warmup steps to perform, by default 10. + do_add_noise : bool, optional + Whether to add noise for following denoising steps or not, + by default True. + use_lcm_lora : bool, optional + Whether to use LCM-LoRA or not, by default True. + use_tiny_vae : bool, optional + Whether to use TinyVAE or not, by default True. + cfg_type : Literal["none", "full", "self", "initialize"], + optional + The cfg_type for img2img mode, by default "self". + You cannot use anything other than "none" for txt2img mode. + seed : int, optional + The seed, by default 2. + + Returns + ------- + StreamDiffusion + The loaded model. + """ + + try: # Load from local directory + pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( + model_id_or_path, + ).to(device=self.device, dtype=self.dtype) + + except ValueError: # Load from huggingface + pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file( + model_id_or_path, + ).to(device=self.device, dtype=self.dtype) + except Exception: # No model found + traceback.print_exc() + print("Model load has failed. 
Doesn't exist.") + exit() + + stream = StreamDiffusion( + pipe=pipe, + t_index_list=t_index_list, + torch_dtype=self.dtype, + width=self.width, + height=self.height, + do_add_noise=do_add_noise, + frame_buffer_size=self.frame_buffer_size, + use_denoising_batch=self.use_denoising_batch, + cfg_type=cfg_type, + ) + if not self.sd_turbo: + if use_lcm_lora: + if lcm_lora_id is not None: + stream.load_lcm_lora( + pretrained_model_name_or_path_or_dict=lcm_lora_id + ) + else: + stream.load_lcm_lora() + stream.fuse_lora() + + if lora_dict is not None: + for lora_name, lora_scale in lora_dict.items(): + stream.load_lora(lora_name) + stream.fuse_lora(lora_scale=lora_scale) + print(f"Use LoRA: {lora_name} in weights {lora_scale}") + + if use_tiny_vae: + if vae_id is not None: + stream.vae = AutoencoderTiny.from_pretrained(vae_id).to( + device=pipe.device, dtype=pipe.dtype + ) + else: + stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to( + device=pipe.device, dtype=pipe.dtype + ) + + try: + if acceleration == "xformers": + stream.pipe.enable_xformers_memory_efficient_attention() + if acceleration == "tensorrt": + from polygraphy import cuda + from streamdiffusion.acceleration.tensorrt import ( + TorchVAEEncoder, + compile_unet, + compile_vae_decoder, + compile_vae_encoder, + ) + from streamdiffusion.acceleration.tensorrt.engine import ( + AutoencoderKLEngine, + UNet2DConditionModelEngine, + ) + from streamdiffusion.acceleration.tensorrt.models import ( + VAE, + UNet, + VAEEncoder, + ) + + def create_prefix( + model_id_or_path: str, + max_batch_size: int, + min_batch_size: int, + ): + maybe_path = Path(model_id_or_path) + if maybe_path.exists(): + return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}" + else: + return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}" + + engine_dir = Path(engine_dir) + unet_path = os.path.join( + engine_dir, + create_prefix( + model_id_or_path=model_id_or_path, + max_batch_size=stream.trt_unet_batch_size, + min_batch_size=stream.trt_unet_batch_size, + ), + "unet.engine", + ) + vae_encoder_path = os.path.join( + engine_dir, + create_prefix( + model_id_or_path=model_id_or_path, + max_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + min_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ), + "vae_encoder.engine", + ) + vae_decoder_path = os.path.join( + engine_dir, + create_prefix( + model_id_or_path=model_id_or_path, + max_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + min_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ), + "vae_decoder.engine", + ) + + if not os.path.exists(unet_path): + os.makedirs(os.path.dirname(unet_path), exist_ok=True) + unet_model = UNet( + fp16=True, + device=stream.device, + max_batch_size=stream.trt_unet_batch_size, + min_batch_size=stream.trt_unet_batch_size, + embedding_dim=stream.text_encoder.config.hidden_size, + unet_dim=stream.unet.config.in_channels, + ) + compile_unet( + stream.unet, + unet_model, + unet_path + ".onnx", + unet_path + ".opt.onnx", + unet_path, + opt_batch_size=stream.trt_unet_batch_size, + ) + + if not os.path.exists(vae_decoder_path): + os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True) + stream.vae.forward = stream.vae.decode + vae_decoder_model = 
VAE( + device=stream.device, + max_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + min_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ) + compile_vae_decoder( + stream.vae, + vae_decoder_model, + vae_decoder_path + ".onnx", + vae_decoder_path + ".opt.onnx", + vae_decoder_path, + opt_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ) + delattr(stream.vae, "forward") + + if not os.path.exists(vae_encoder_path): + os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True) + vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda")) + vae_encoder_model = VAEEncoder( + device=stream.device, + max_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + min_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ) + compile_vae_encoder( + vae_encoder, + vae_encoder_model, + vae_encoder_path + ".onnx", + vae_encoder_path + ".opt.onnx", + vae_encoder_path, + opt_batch_size=self.batch_size + if self.mode == "txt2img" + else stream.frame_bff_size, + ) + + cuda_stream = cuda.Stream() + + vae_config = stream.vae.config + vae_dtype = stream.vae.dtype + + stream.unet = UNet2DConditionModelEngine( + unet_path, cuda_stream, use_cuda_graph=False + ) + stream.vae = AutoencoderKLEngine( + vae_encoder_path, + vae_decoder_path, + cuda_stream, + stream.pipe.vae_scale_factor, + use_cuda_graph=False, + ) + setattr(stream.vae, "config", vae_config) + setattr(stream.vae, "dtype", vae_dtype) + + gc.collect() + torch.cuda.empty_cache() + + print("TensorRT acceleration enabled.") + if acceleration == "sfast": + from streamdiffusion.acceleration.sfast import ( + accelerate_with_stable_fast, + ) + + stream = accelerate_with_stable_fast(stream) + print("StableFast acceleration enabled.") + except Exception: + traceback.print_exc() + print("Acceleration has failed. Falling back to normal mode.") + + if seed < 0: # Random seed + seed = np.random.randint(0, 1000000) + + stream.prepare( + "", + "", + num_inference_steps=50, + guidance_scale=1.1 + if stream.cfg_type in ["full", "self", "initialize"] + else 1.0, + generator=torch.manual_seed(seed), + seed=seed, + ) + + if self.use_safety_checker: + from transformers import CLIPFeatureExtractor + from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, + ) + + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ).to(pipe.device) + self.feature_extractor = CLIPFeatureExtractor.from_pretrained( + "openai/clip-vit-base-patch32" + ) + self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0)) + + return stream
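
A minimal img2img sketch of how the wrapper above might be driven (the checkpoint id, prompts, and file names are placeholders; the warm-up loop follows the upstream StreamDiffusion examples):

from PIL import Image

from utils.wrapper import StreamDiffusionWrapper

stream = StreamDiffusionWrapper(
    model_id_or_path="KBlueLeaf/kohaku-v2.1",  # placeholder: any SD 1.5 checkpoint id or local path
    t_index_list=[32, 45],
    mode="img2img",
    acceleration="xformers",  # "tensorrt" builds engines into engine_dir on first run
    cfg_type="self",
)
stream.prepare(
    prompt="1girl with brown hair, thick frame glasses",
    negative_prompt="low quality, bad quality, blurry",
    num_inference_steps=50,
    guidance_scale=1.2,
)

frame = Image.open("input.png")  # placeholder input frame
for _ in range(stream.batch_size - 1):  # fill the denoising batch before reading output
    stream(image=frame)
stream(image=frame).save("output.png")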
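
A similar sketch for utils/viewer.py, assuming a producer process that pushes raw output tensors (random data stands in for StreamDiffusion output here) while the Tk window runs in the main process:

import time
from multiprocessing import Process, Queue

import torch

from utils.viewer import receive_images


def produce(queue: Queue, fps_queue: Queue) -> None:
    # Stand-in producer: pushes random frames; a real producer would push
    # the wrapper's output tensors instead.
    prev = time.time()
    while True:
        queue.put(torch.rand(1, 3, 512, 512))
        now = time.time()
        fps_queue.put(1.0 / max(now - prev, 1e-6))
        prev = now
        time.sleep(1 / 30)


if __name__ == "__main__":
    queue, fps_queue = Queue(), Queue()
    Process(target=produce, args=(queue, fps_queue), daemon=True).start()
    receive_images(queue, fps_queue)  # blocks in Tk's mainloop until the window is closed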