Source code for light_side.datasets.zerodce

import os
from typing import Tuple

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torchvision.datasets.utils import extract_archive

from ..core import _download_file_from_url
from .base import BaseDataset, Identity


[docs]class ZeroDCE(BaseDataset): """ ZeroDCE is a dataset mentioned in the paper and in the original implementation. """ __phases__ = ( "train", "val", ) def __init__( self, root_dir: str = None, phase: str = "train", transforms=None, **kwargs, ): if root_dir is None: root_dir = "TODO" self.phase = phase self.root_dir = root_dir self.transforms = Identity() if transforms is None else transforms self.download() ids, targets = self._split_dataset(phase=phase) super().__init__(ids, targets, transforms=transforms, **kwargs) def __getitem__(self, idx: int) -> Tuple: img = self._load_image(self.ids[idx]) img = torch.from_numpy(img).float().permute(2, 0, 1) # apply transforms if self.transforms: img = self.transforms(img) return (img, img)
[docs] def _split_dataset(self, phase) -> Tuple: filenames = [] data_dir = os.path.join(self.root_dir) for item in os.listdir(data_dir): f = os.path.join(data_dir, item) if os.path.isfile(f): filenames.append(f) else: for subitem in os.listdir(f): sub_f = os.path.join(f, subitem) filenames.append(sub_f) filenames = np.asarray(filenames) filenames = filenames[filenames.argsort()] idxs = range(len(filenames)) # split into a is_train and test set as provided data is not presplit x_train, x_test, y_train, y_test = train_test_split( filenames, idxs, test_size=0.2, random_state=1, ) if phase == "train": return x_train.tolist(), y_train elif phase == "val": return x_test.tolist(), y_test else: raise ValueError("Unknown phase")
[docs] def _check_exists(self) -> bool: """ Check the Root directory is exists """ return os.path.exists(self.root_dir)
[docs] def download(self) -> None: """ Download the dataset from the internet """ if self._check_exists(): return os.makedirs(self.root_dir, exist_ok=True) _download_file_from_url( "https://drive.google.com/u/0/uc?id=1IXluAUo_3yFodOfr1clwxobQhfu6ApCJ&export=download&confirm=t", os.path.join(self.root_dir, "zerodce.zip"), ) extract_archive( os.path.join(self.root_dir, "zerodce.zip"), self.root_dir, remove_finished=True, )
if __name__ == "__main__": data = ZeroDCE("light_side/datas/zerodce") print(data[0]) print(data.classes) print(len(data.classes))