Data
Generating and consolidating the data.
- data.generate_fake_images(selected_classes, class_index_json, output_dir, max_images_per_class=1000, device='cuda')
Generates fake images for selected classes using Stable Diffusion.
For each class identifier (WNID) in selected_classes, the function looks up the corresponding class name in a JSON file (class_index_json), generates images from that class name with Stable Diffusion, and saves them to output_dir. At most max_images_per_class images are generated for each class. If images already exist for a class, generation is skipped for that class.
- Parameters:
selected_classes (list of str) – A list of class identifiers (WNIDs) for which fake images should be generated.
class_index_json (str) – Path to a JSON file that maps class indices to class names.
output_dir (str) – Directory where the generated images will be saved.
max_images_per_class (int, optional) – The maximum number of images to generate per class. Default is 1000.
device (str, optional) – The device on which the model should run (e.g., 'cuda' or 'cpu'). Default is 'cuda'.
- Returns:
None
Example
    from data import generate_fake_images

    selected_classes = ["n02096585", "n02129604"]
    class_index_json = "path/to/class_index.json"
    output_dir = "path/to/output"
    generate_fake_images(selected_classes, class_index_json, output_dir)
    # Fake images for the selected classes are generated and saved in the output directory.
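The lookup-and-skip behaviour described above can be sketched without loading the model. This is a minimal illustration, not the module's actual code: it assumes the JSON file follows the common ImageNet layout {"0": ["n01440764", "tench"], ...} and that images for each class live in a subfolder of output_dir named after the WNID. The helper names load_wnid_to_name and classes_to_generate are hypothetical.

```python
import json
import os


def load_wnid_to_name(class_index_json):
    """Return {wnid: class_name}.

    Assumes the ImageNet-style layout {"0": ["n01440764", "tench"], ...};
    the real file may be organised differently.
    """
    with open(class_index_json, "r", encoding="utf-8") as f:
        class_index = json.load(f)
    return {wnid: name for wnid, name in class_index.values()}


def classes_to_generate(selected_classes, class_index_json, output_dir):
    """List (wnid, prompt_name) pairs for classes that have no images yet.

    Assumes one subfolder per WNID under output_dir (hypothetical layout).
    An unknown WNID raises KeyError.
    """
    wnid_to_name = load_wnid_to_name(class_index_json)
    pending = []
    for wnid in selected_classes:
        class_dir = os.path.join(output_dir, wnid)
        # Skip classes whose folder already contains at least one file.
        if os.path.isdir(class_dir) and os.listdir(class_dir):
            continue
        # Underscores in class names read poorly as text prompts.
        pending.append((wnid, wnid_to_name[wnid].replace("_", " ")))
    return pending
```

Each returned pair could then be passed to whatever text-to-image pipeline is in use, with the class name serving as the prompt.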
- data.load_real_images(data_dir, selected_classes, max_images_per_class=1000)
Loads a subset of real images from a directory, filtering by selected classes and limiting the number of images per class.
This function loads images from the specified directory, selecting only those belonging to the classes provided in selected_classes. It ensures that no more than max_images_per_class images are loaded per class. The images are preprocessed by resizing and center-cropping them to 256x256 pixels and converting them to tensor format.
- Parameters:
data_dir (str) – The directory containing the image data organized into subfolders, where each subfolder corresponds to a class.
selected_classes (list of str) – A list of class names to filter the dataset by.
max_images_per_class (int, optional) – The maximum number of images to load per selected class. Default is 1000.
- Returns:
A subset of the ImageFolder dataset containing the filtered and preprocessed images.
- Return type:
torch.utils.data.Subset
Example
    from data import load_real_images

    data_dir = "path/to/imagenet"
    selected_classes = ["tiger", "koala", "hamster"]
    dataset = load_real_images(data_dir, selected_classes, max_images_per_class=500)
    print(len(dataset))  # Prints the number of images loaded.
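The class filtering and per-class cap can be illustrated as plain index selection. The sketch below is an assumption about how such a filter might work, not the module's implementation: it takes the samples list and class_to_idx mapping that torchvision's ImageFolder exposes and returns indices suitable for torch.utils.data.Subset. The function name select_indices is hypothetical.

```python
from collections import defaultdict


def select_indices(samples, class_to_idx, selected_classes, max_images_per_class=1000):
    """Return indices of samples in the selected classes, capped per class.

    samples      -- list of (path, class_index) pairs, as in ImageFolder.samples
    class_to_idx -- dict mapping class folder name to class index
    """
    wanted = {class_to_idx[name] for name in selected_classes}
    counts = defaultdict(int)
    indices = []
    for i, (_path, class_idx) in enumerate(samples):
        # Keep the sample only if its class is selected and not yet full.
        if class_idx in wanted and counts[class_idx] < max_images_per_class:
            counts[class_idx] += 1
            indices.append(i)
    return indices
```

With torchvision installed, the result would be used as Subset(image_folder, select_indices(image_folder.samples, image_folder.class_to_idx, selected_classes)).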
- data.prepare_image_paths_labels(real_dataset, fake_dir, selected_classes)
Prepares a combined list of image paths and labels from real and fake datasets.
This function creates a list of image paths and their corresponding labels for both real and fake images. For real images, the paths and labels are obtained from the real_dataset. For fake images, the paths are gathered from the specified fake_dir for each class in selected_classes. The real images are labeled as 0, and the fake images are labeled as 1. The function returns a combined list of tuples containing image paths and their associated labels.
- Parameters:
real_dataset (torch.utils.data.Subset) – A subset of the real image dataset, typically with transformations applied.
fake_dir (str) – Directory where fake images are stored, organized by class.
selected_classes (list of str) – A list of selected classes for which fake images should be included.
- Returns:
- A list of tuples, where each tuple contains:
str: The file path to an image.
int: The label for the image (0 for real, 1 for fake).
- Return type:
list of tuples
Example
    from data import load_real_images, prepare_image_paths_labels

    selected_classes = ["hamster", "zebra"]
    real_dataset = load_real_images("path/to/real_images", selected_classes=selected_classes)
    fake_dir = "path/to/fake_images"
    image_paths_labels = prepare_image_paths_labels(real_dataset, fake_dir, selected_classes)
    print(image_paths_labels)  # Prints the list of image paths with labels.
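The 0/1 labelling scheme can be sketched as follows. This is a hedged illustration rather than the module's code: it assumes real_dataset behaves like a torch.utils.data.Subset wrapping an ImageFolder (so it has indices and dataset.samples), and that fake_dir holds one subfolder per entry in selected_classes. The name combine_paths_labels is hypothetical.

```python
import os

REAL_LABEL, FAKE_LABEL = 0, 1


def combine_paths_labels(real_dataset, fake_dir, selected_classes):
    """Return [(path, label), ...] with real images first, then fake ones."""
    # Subset.indices points into the wrapped dataset's (path, class_index) list.
    samples = real_dataset.dataset.samples
    combined = [(samples[i][0], REAL_LABEL) for i in real_dataset.indices]

    for class_name in selected_classes:
        class_dir = os.path.join(fake_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # No fake images were generated for this class.
        # Sort filenames so the output order is reproducible.
        for filename in sorted(os.listdir(class_dir)):
            combined.append((os.path.join(class_dir, filename), FAKE_LABEL))
    return combined
```

Note that the original class index of each real image is discarded here; only the real-versus-fake label is kept, matching the return value described above.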