@@ -40,11 +40,14 @@ def _require_ctc_dataset(path, dataset_name, download):
40
40
41
41
data_path = os .path .join (path , dataset_name )
42
42
43
- if not os .path .exists (data_path ):
44
- url , checksum = CTC_URLS [dataset_name ], CTC_CHECKSUMS [dataset_name ]
45
- zip_path = os .path .join (path , f"{ dataset_name } .zip" )
46
- util .download_source (zip_path , url , download , checksum = checksum )
47
- util .unzip (zip_path , path , remove = True )
43
+ if os .path .exists (data_path ):
44
+ return data_path
45
+
46
+ os .makedirs (data_path )
47
+ url , checksum = CTC_URLS [dataset_name ], CTC_CHECKSUMS [dataset_name ]
48
+ zip_path = os .path .join (path , f"{ dataset_name } .zip" )
49
+ util .download_source (zip_path , url , download , checksum = checksum )
50
+ util .unzip (zip_path , path , remove = True )
48
51
49
52
return data_path
50
53
@@ -101,6 +104,8 @@ def get_ctc_segmentation_dataset(
101
104
splits = glob (os .path .join (data_path , "*_GT" ))
102
105
splits = [os .path .basename (split ) for split in splits ]
103
106
splits = [split .rstrip ("_GT" ) for split in splits ]
107
+ else :
108
+ splits = split
104
109
105
110
image_path , label_path = _require_gt_images (data_path , splits )
106
111
0 commit comments