PyTorch 中的 CocoCaptions (3)
请我喝杯咖啡☕
*备忘录:
CocoCaptions() 可以使用 MS COCO 数据集,如下所示。*以下示例依次针对:带有 stuff_train2017.json 的 train2017、带有 stuff_val2017.json 的 val2017、带有 stuff_train2017.json 的 stuff_train2017_pixelmaps、带有 stuff_val2017.json 的 stuff_val2017_pixelmaps、带有 panoptic_train2017.json 的 panoptic_train2017、带有 panoptic_val2017.json 的 panoptic_val2017,以及带有 image_info_unlabeled2017.json 的 unlabeled2017:
# Demonstrates which MS COCO splits torchvision's ``CocoCaptions`` can load
# and index. Values noted in comments are the outputs recorded by the author.
import matplotlib.pyplot as plt
from torchvision.datasets import CocoCaptions

# --- COCO-Stuff annotations paired with the regular train/val images ---
stf_train2017_data = CocoCaptions(
    root="data/coco/imgs/train2017",
    annFile="data/coco/anns/stuff_trainval2017/stuff_train2017.json"
)
stf_val2017_data = CocoCaptions(
    root="data/coco/imgs/val2017",
    annFile="data/coco/anns/stuff_trainval2017/stuff_val2017.json"
)
print((len(stf_train2017_data), len(stf_val2017_data)))
# (118287, 5000)

# --- COCO-Stuff pixel-map folders used as the image root ---
pms_stf_train2017_data = CocoCaptions(
    root="data/coco/anns/stuff_trainval2017/stuff_train2017_pixelmaps",
    annFile="data/coco/anns/stuff_trainval2017/stuff_train2017.json"
)
pms_stf_val2017_data = CocoCaptions(
    root="data/coco/anns/stuff_trainval2017/stuff_val2017_pixelmaps",
    annFile="data/coco/anns/stuff_trainval2017/stuff_val2017.json"
)
print((len(pms_stf_train2017_data), len(pms_stf_val2017_data)))
# (118287, 5000)

# --- Panoptic annotations ---
# Constructing these datasets raises an error (as recorded by the author),
# so the calls are kept commented out for reference.
# pan_train2017_data = CocoCaptions(
#     root="data/coco/anns/panoptic_trainval2017/panoptic_train2017",
#     annFile="data/coco/anns/panoptic_trainval2017/panoptic_train2017.json"
# ) # Error
# pan_val2017_data = CocoCaptions(
#     root="data/coco/anns/panoptic_trainval2017/panoptic_val2017",
#     annFile="data/coco/anns/panoptic_trainval2017/panoptic_val2017.json"
# ) # Error

# --- Unlabeled images (no captions) ---
unlabeled2017_data = CocoCaptions(
    root="data/coco/imgs/unlabeled2017",
    annFile="data/coco/anns/unlabeled2017/image_info_unlabeled2017.json"
)
print(len(unlabeled2017_data))
# 123403

ims = (2, 47, 64)

# Indexing the stuff and pixel-map datasets raises an error (as recorded by the
# author). As bare statements, the first failure would abort the whole script,
# so each failure is reported here and the script keeps going.
# NOTE(review): presumably the stuff annotations carry no "caption" entries
# that CocoCaptions expects — TODO confirm the exact exception per dataset.
for name, dataset in (
    ("stf_train2017_data", stf_train2017_data),
    ("stf_val2017_data", stf_val2017_data),
    ("pms_stf_train2017_data", pms_stf_train2017_data),
    ("pms_stf_val2017_data", pms_stf_val2017_data),
):
    for idx in ims:
        try:
            dataset[idx]
        except Exception as err:  # demo boundary: report the error and continue
            print(f"{name}[{idx}] -> {type(err).__name__}: {err}")

for idx in ims:
    print(unlabeled2017_data[idx])
# (<PIL.Image.Image image mode=RGB size=640x427>, [])
# (<PIL.Image.Image image mode=RGB size=428x640>, [])
# (<PIL.Image.Image image mode=RGB size=640x480>, [])


def show_images(data, ims, main_title=None):
    """Plot the samples at indices *ims* of *data* side by side in one row.

    Args:
        data: An indexable dataset whose items are ``(image, captions)``
            pairs, such as a ``CocoCaptions`` instance.
        ims: Iterable of sample indices; one panel is drawn per index.
        main_title: Optional title for the whole figure.
    """
    ims = list(ims)  # allow any iterable; the length sizes the grid
    if not ims:
        return  # nothing to draw (plt.subplots rejects ncols=0)
    # squeeze=False keeps ``axes`` a 2-D array even for a single panel,
    # so ``.ravel()`` always works.
    fig, axes = plt.subplots(nrows=1, ncols=len(ims), figsize=(14, 8),
                             squeeze=False)
    fig.suptitle(t=main_title, y=0.9, fontsize=14)
    for i, axis in zip(ims, axes.ravel()):
        im, captions = data[i]  # load each sample once
        axis.imshow(X=im)
        # Previously, captioned samples were skipped, leaving blank panels;
        # now they are drawn with their first caption as the panel title.
        if captions:
            axis.set_title(label=captions[0])
    fig.tight_layout()
    plt.show()


show_images(data=unlabeled2017_data, ims=ims, main_title="unlabeled2017_data")