import os
import logging
import torch
from torchvision import transforms as trn
from torchvision.datasets import ImageFolder
from diffusers import StableDiffusionPipeline
import json
# -------------------------
# Load Real Images
# -------------------------
def load_real_images(data_dir, selected_classes, max_images_per_class=1000):
"""
Loads a subset of real images from a directory, filtering by selected classes and limiting the number of images per class.
This function loads images from the specified directory, selecting only those belonging to the classes
provided in `selected_classes`. It ensures that no more than `max_images_per_class` images are loaded
per class. The images are preprocessed by resizing and center-cropping them to 256x256 pixels and converting
them to tensors.
Args:
data_dir (str): The directory containing the image data organized into subfolders, where each subfolder
corresponds to a class.
selected_classes (list of str): A list of class names to filter the dataset by.
max_images_per_class (int, optional): The maximum number of images to load per selected class. Default is 1000.
Returns:
torch.utils.data.Subset: A subset of the `ImageFolder` dataset containing the filtered and preprocessed images.
Example:
data_dir = "path/to/imagenet"
selected_classes = ["tiger", "koala", "hamster"]
dataset = load_real_images(data_dir, selected_classes, max_images_per_class=500)
print(len(dataset)) # Prints the number of images loaded.
"""
# Real images are taken from ImageNet: 1000 images from a small subset of 10 classes (hamster, zebra, castle, fountain, koala, tiger, monarch butterfly, flamingo, knot, forklift).
logging.info("Loading real images...")
transform_real = trn.Compose([
trn.Resize(256),
trn.CenterCrop(256),
trn.ToTensor(),
])
dataset = ImageFolder(root=data_dir, transform=transform_real)
class_to_idx = dataset.class_to_idx
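# Map the selected class names to label indices and keep only samples from those classes.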
selected_class_indices = [class_to_idx[cls] for cls in selected_classes if cls in class_to_idx]
indices = [i for i, (_, label) in enumerate(dataset.samples) if label in selected_class_indices]
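# Cap the number of retained samples at max_images_per_class for each selected class.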
class_counts = {cls_idx: 0 for cls_idx in selected_class_indices}
limited_indices = []
for idx in indices:
_, label = dataset.samples[idx]
if class_counts[label] < max_images_per_class:
limited_indices.append(idx)
class_counts[label] += 1
subset_dataset = torch.utils.data.Subset(dataset, limited_indices)
logging.info(f"Total real images loaded: {len(subset_dataset)}")
return subset_dataset
# -------------------------
# Generate Fake Images
# -------------------------
def generate_fake_images(selected_classes, class_index_json, output_dir, max_images_per_class=1000, device='cuda'):
"""
Generates fake images for selected classes using Stable Diffusion.
This function uses the Stable Diffusion model to generate fake images for the specified classes. It takes
class identifiers (WNID) from `selected_classes` and retrieves the corresponding class names from a JSON file
(`class_index_json`). It generates images based on the class names and saves them to the `output_dir`.
At most `max_images_per_class` images are generated for each class. If a class already has that many images,
it is skipped; otherwise only the remaining images are generated.
Args:
selected_classes (list of str): A list of class identifiers (WNIDs) for which fake images should be generated.
class_index_json (str): Path to a JSON file that maps class indices to `[WNID, class_name]` pairs (the standard ImageNet class-index format).
output_dir (str): Directory where the generated images will be saved.
max_images_per_class (int, optional): The maximum number of images to generate per class. Default is 1000.
device (str, optional): The device on which the model should run (e.g., 'cuda' or 'cpu'). Default is 'cuda'.
Returns:
None
Example:
selected_classes = ["n02096585", "n02129604"]
class_index_json = "path/to/class_index.json"
output_dir = "path/to/output"
generate_fake_images(selected_classes, class_index_json, output_dir)
# Fake images for the selected classes are generated and saved in the output directory.
"""
# Generates fake images with Stable Diffusion.
logging.info("Checking existing fake images...")
os.makedirs(output_dir, exist_ok=True)
with open(class_index_json, 'r') as f:
class_index = json.load(f)
# The class-index JSON is assumed to map index -> [WNID, class_name]; build a WNID -> class name lookup.
wnid_to_class_name = {value[0]: value[1] for value in class_index.values()}
sd_pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
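# Disable the built-in NSFW safety checker so every generated image is returned unmodified.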
sd_pipeline.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
for wnid in selected_classes:
if wnid not in wnid_to_class_name:
logging.warning(f"WNID {wnid} not found in class index. Skipping.")
continue
class_name = wnid_to_class_name[wnid]
prompt = class_name.replace('_', ' ')
class_dir = os.path.join(output_dir, wnid)
os.makedirs(class_dir, exist_ok=True)
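# Count the PNGs already on disk so only the missing images for this class are generated.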
existing_images = [img for img in os.listdir(class_dir) if img.endswith('.png')]
if len(existing_images) >= max_images_per_class:
logging.info(f"Images for class {class_name} already exist. Skipping.")
continue
logging.info(f"Generating images for class: {class_name} (WNID: {wnid})")
for idx in range(max_images_per_class - len(existing_images)):
try:
image = sd_pipeline(prompt).images[0]
image.save(os.path.join(class_dir, f"{wnid}_{len(existing_images) + idx:05d}.png"))
except Exception as e:
logging.error(f"Failed to generate image for {class_name}, index {idx}: {e}")
logging.info("Fake image generation complete.")
# -------------------------
# Prepare Real and Fake Datasets
# -------------------------
def prepare_image_paths_labels(real_dataset, fake_dir, selected_classes):
"""
Prepares a combined list of image paths and labels from real and fake datasets.
This function creates a list of image paths and their corresponding labels for both real and fake images.
For real images, the paths and labels are obtained from the `real_dataset`. For fake images, the paths
are gathered from the specified `fake_dir` for each class in `selected_classes`. The real images are labeled as 0,
and the fake images are labeled as 1. The function returns a combined list of tuples containing image paths
and their associated labels.
Args:
real_dataset (torch.utils.data.Subset): A subset of the real image dataset, typically with transformations applied.
fake_dir (str): Directory where fake images are stored, organized by class.
selected_classes (list of str): A list of selected classes for which fake images should be included. These must match the subfolder names under `fake_dir` (the WNIDs used by `generate_fake_images`).
Returns:
list of tuples: A list of tuples, where each tuple contains:
- str: The file path to an image.
- int: The label for the image (0 for real, 1 for fake).
Example:
real_dataset = load_real_images("path/to/real_images", selected_classes=["hamster", "zebra"])
fake_dir = "path/to/fake_images"
selected_classes = ["hamster", "zebra"]
image_paths_labels = prepare_image_paths_labels(real_dataset, fake_dir, selected_classes)
print(image_paths_labels) # Prints the list of image paths with labels.
"""
real_image_paths_labels = [(real_dataset.dataset.samples[idx][0], 0) for idx in real_dataset.indices]
fake_image_paths_labels = []
for cls in selected_classes:
class_dir = os.path.join(fake_dir, cls)
if not os.path.exists(class_dir):
continue
fake_images = [os.path.join(class_dir, img_name) for img_name in os.listdir(class_dir) if img_name.endswith('.png')]
fake_image_paths_labels.extend([(img_path, 1) for img_path in fake_images])
return real_image_paths_labels + fake_image_paths_labels
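# -------------------------
# Example Usage
# -------------------------
# Illustrative sketch of how the helpers above fit together; the paths below are placeholders.
# It assumes the real data directory uses WNID subfolder names, so the same selected_classes
# list works for the real dataset, the fake-image generator, and the path/label preparation step.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    data_dir = "path/to/imagenet"
    fake_dir = "path/to/fake_images"
    class_index_json = "path/to/class_index.json"
    selected_classes = ["n02096585", "n02129604"]

    real_dataset = load_real_images(data_dir, selected_classes, max_images_per_class=500)
    generate_fake_images(selected_classes, class_index_json, fake_dir, max_images_per_class=500)
    image_paths_labels = prepare_image_paths_labels(real_dataset, fake_dir, selected_classes)
    logging.info(f"Total images (real + fake): {len(image_paths_labels)}")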