"""
DeepForest callbacks.

Callbacks must implement the on_epoch_begin, on_epoch_end, on_fit_begin and on_fit_end methods and accept injected model and epoch keyword arguments.
"""
import glob
import os
import tempfile

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback

from deepforest import dataset
from deepforest import utilities
from deepforest import visualize
class images_callback(Callback):
    """Periodically predict on a small sample of annotated images and upload the plots.

    At construction, up to ``n`` distinct images are sampled from ``csv_file``.
    Their annotation rows are written to ``<savedir>/image_callback.csv``.
    On every epoch where ``current_epoch % every_n_epochs == 0``, the following happens:

    1. ``pl_module.predict_file`` is run on that csv with ``savedir=savedir``.
    2. Every ``*.png`` found in ``savedir`` is passed to
       ``pl_module.logger.experiment.log_image``.

    Args:
        csv_file: path to csv with columns image_path, xmin, ymin, xmax, ymax, label
        root_dir: root directory in which to look up the 'image_path' values from the csv file
        savedir: directory where the sampled csv and the predicted images are written;
            created if it does not already exist
        n: number of distinct images to sample (capped at the number of distinct
            images in ``csv_file``)
        every_n_epochs: run the callback on epochs divisible by this interval

    Returns:
        None: images are uploaded to the logger's experiment when one is available;
        otherwise a message is printed and the images remain in ``savedir``.
    """

    def __init__(self, csv_file, root_dir, savedir, n=2, every_n_epochs=5):
        super().__init__()
        self.savedir = savedir
        self.root_dir = root_dir
        self.n = n
        self.every_n_epochs = every_n_epochs

        # Limit to n distinct images. Sampling without replacement stops duplicates
        # from silently shrinking the sample. Capping the size stops a csv with
        # fewer than n images from raising.
        df = pd.read_csv(csv_file)
        image_paths = df.image_path.unique()
        selected_images = np.random.choice(image_paths,
                                           size=min(self.n, len(image_paths)),
                                           replace=False)
        df = df[df.image_path.isin(selected_images)]

        os.makedirs(savedir, exist_ok=True)
        self.csv_file = os.path.join(savedir, "image_callback.csv")
        # index=False: otherwise the DataFrame index is written as an extra unnamed column.
        df.to_csv(self.csv_file, index=False)

    def log_images(self, pl_module):
        """Predict on the sampled csv and upload the saved plots, best effort.

        Args:
            pl_module: module exposing ``predict_file`` and, optionally, a
                ``logger.experiment`` with a ``log_image`` method.
        """
        # Called for its side effect of writing plots into savedir;
        # the returned predictions are not needed here.
        pl_module.predict_file(self.csv_file, self.root_dir, savedir=self.savedir)
        # Escape savedir so glob metacharacters in the directory name are matched literally.
        saved_plots = glob.glob(os.path.join(glob.escape(self.savedir), "*.png"))
        try:
            for plot_path in saved_plots:
                pl_module.logger.experiment.log_image(plot_path)
        except Exception as e:
            # Uploading is optional: a missing or incompatible logger must not
            # interrupt training, so report the problem and continue.
            print(
                "Could not find logger in lightning module, skipping upload, images were saved to {}, error was raised {}"
                .format(self.savedir, e))

    def on_epoch_end(self, trainer, pl_module):
        """Run ``log_images`` on epochs divisible by ``every_n_epochs``.

        NOTE(review): ``on_epoch_end`` is deprecated in newer PyTorch Lightning
        releases. Confirm that the installed version still calls this hook.
        """
        if trainer.current_epoch % self.every_n_epochs == 0:
            print("Running image callback")
            self.log_images(pl_module)