TDS Archive

An archive of data science, data analytics, data engineering, machine learning, and artificial…

5-Minute Data Science Design Patterns I: Callback

Nok Chan
Published in TDS Archive
5 min read · Jul 9, 2021


Photo by Pavan Trikutam on Unsplash

Note: This series is meant as a quick introduction to software design for data scientists. It is something more lightweight than the design-pattern bible, Clean Code, and the kind of guide I wish had existed when I first started learning. Design patterns are reusable solutions to common problems, and some of them happen to be useful for data science. There is a good chance that someone else has already solved your problem. When used wisely, design patterns help reduce the complexity of your code.


So, what is a callback, after all?

A callback function ("call after") is simply a function that is called after another function. More precisely, it is a piece of executable code (a function) that is passed as an argument to another function, which then calls it. [1]

def foo(x, callback=None):
    print('foo!')
    if callback:
        callback(x)
    return None

>>> foo('123')
foo!
>>> foo('123', print)
foo!
123

Here I pass the built-in function print as the callback, so the string 123 gets printed after foo!.
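The callback does not have to be print; any callable works. Here is a small sketch that reuses foo and passes a bound method and a lambda as callbacks (the results list and the lambda are illustrative, not from the original):

```python
def foo(x, callback=None):
    print('foo!')
    if callback:
        callback(x)
    return None


results = []

# A bound method is a callable, so it can be passed as the callback.
foo('123', results.append)   # prints foo!, then appends '123'

# So is a lambda; this one appends the string repeated twice.
foo('123', lambda s: results.append(s * 2))

print(results)  # ['123', '123123']
```

Both calls print foo! first and only then run the callback, which is exactly the "call after" behaviour.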

Why do I need callbacks?

Callbacks are very common in high-level deep learning libraries, where you will most likely find them in the training loop. For example:

  • fastai: a high-level API for PyTorch
  • Keras: the high-level API for TensorFlow
  • ignite: uses events and handlers instead, which in its authors' opinion provides more flexibility

To see where callbacks fit, start with a boring training loop:

import numpy as np

# A boring training loop
def train(x):
    n_epochs = 3
    n_batches = 2
    loss = 20

    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model
    return loss

>>> x = np.ones(10)
>>> train(x)
14

Now say you want to print the loss at the end of each epoch. You can just add one line of code.

The simple approach

def train_with_print(x):
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model
        print(f'End of Epoch. Epoch: {epoch}, Loss: {loss}')
    return loss

>>> train_with_print(x)
End of Epoch. Epoch: 0, Loss: 18
End of Epoch. Epoch: 1, Loss: 16
End of Epoch. Epoch: 2, Loss: 14
14

Callback approach

Or you could add a PrintCallback, which does the same thing with a bit more code.

class Callback:
    """Base class: every hook is a no-op, so subclasses override only what they need."""
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass


class PrintCallback(Callback):
    def __init__(self):
        self.epoch = 0  # the callback keeps its own epoch counter

    def on_epoch_end(self, x):
        print(f'End of Epoch. Epoch: {self.epoch}, Loss: {x}')
        self.epoch += 1


def train_with_callback(x, callback=None):
    if callback is None:
        callback = Callback()  # a do-nothing callback
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)
    return loss

>>> train_with_callback(x, callback=PrintCallback())
End of Epoch. Epoch: 0, Loss: 18
End of Epoch. Epoch: 1, Loss: 16
End of Epoch. Epoch: 2, Loss: 14
14

Usually, a callback defines a few event hooks named on_xxx_xxx, and each hook is executed when the corresponding event happens in the training loop. All callbacks inherit from the base class Callback and override only the hooks they need. Here we only implement on_epoch_end, because we only want to show the loss at the end of each epoch.
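To make "override only the hooks you need" concrete, here is a minimal sketch of a hypothetical LossHistory callback that records the loss after every batch. It repeats the Callback base class and the train_with_callback loop from above so it runs on its own; the loop ignores x, so None is passed.

```python
class Callback:
    """Base class: every hook is a no-op."""
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass


class LossHistory(Callback):
    """Hypothetical callback: records the loss after every batch.
    Only on_batch_end is overridden; the other hooks stay no-ops."""
    def __init__(self):
        self.losses = []

    def on_batch_end(self, x):
        self.losses.append(x)


def train_with_callback(x, callback=None):
    if callback is None:
        callback = Callback()  # do-nothing default
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)
    return loss


history = LossHistory()
final_loss = train_with_callback(None, callback=history)
print(history.losses)  # [19, 18, 17, 16, 15, 14]
print(final_loss)      # 14
```

The training loop did not change at all to support the new behaviour; only a new class was added.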

It may seem awkward to write so much more code to do one simple thing, but there are good reasons for it. Suppose you now need to add more features. How would you do it?

  • ModelCheckpoint
  • Early Stopping
  • LearningRateScheduler

You could just add more code to the loop, but it would soon grow into a really big function. Such a function is very hard to test because it does 10 things at the same time. In addition, the extra code may not even be related to the training logic; it is only there to save the model or plot a chart. So it is best to separate the logic: according to the Single Responsibility Principle, a function should do only one thing. This reduces complexity, because you no longer need to worry about accidentally breaking 10 things at once, and it is much easier to consider one thing at a time.
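Separating the logic also makes each piece testable on its own. As a minimal sketch, assume a hypothetical BestCheckpoint callback that would save the model only when the loss improves (saving is simulated by recording the loss, so no file is written). Its logic can be checked by calling the hook directly with made-up losses, without running any training loop:

```python
class Callback:
    """Base class: every hook is a no-op."""
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass


class BestCheckpoint(Callback):
    """Hypothetical checkpoint callback: 'saves' only when the loss improves.
    Saving is simulated by recording the loss instead of writing a file."""
    def __init__(self):
        self.best = float('inf')
        self.saved = []

    def on_epoch_end(self, x):
        if x < self.best:
            self.best = x
            self.saved.append(x)  # stand-in for writing a checkpoint to disk


# A unit test needs no training loop: call the hook with fake losses.
ckpt = BestCheckpoint()
for loss in [10, 12, 8, 9, 7]:
    ckpt.on_epoch_end(loss)
assert ckpt.saved == [10, 8, 7]
```

Testing the same rule inside one big training function would mean running the whole loop and inspecting its side effects.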

With the Callback pattern, I only need to implement a few more classes, and the training loop is barely touched. I do have to change the training function a bit, since it now has to accept more than one callback.

A Callbacks class that wraps a list of callbacks

class Callbacks:
    """
    Container for a list of callbacks: forwards each event to every callback.
    """

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_epoch_start(self, x):
        for callback in self.callbacks:
            callback.on_epoch_start(x)

    def on_epoch_end(self, x):
        for callback in self.callbacks:
            callback.on_epoch_end(x)

    def on_batch_start(self, x):
        for callback in self.callbacks:
            callback.on_batch_start(x)

    def on_batch_end(self, x):
        for callback in self.callbacks:
            callback.on_batch_end(x)
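The four methods of Callbacks are nearly identical. As an alternative design (a sketch, not the article's implementation), __getattr__ can build the forwarding function for any on_* event name, so the training loop can keep calling callbacks.on_epoch_end(loss) unchanged. Recorder is a hypothetical callback used only to show the dispatch order:

```python
class Callbacks:
    """Alternative container sketch: instead of one hand-written method per
    event, __getattr__ builds a forwarding function for any 'on_*' name."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def __getattr__(self, event):
        # Python calls __getattr__ only when normal lookup fails,
        # e.g. for 'on_epoch_end', which this class does not define.
        if not event.startswith('on_'):
            raise AttributeError(event)

        def dispatch(x):
            for callback in self.callbacks:
                getattr(callback, event)(x)

        return dispatch


class Recorder:
    """Hypothetical callback that logs which instance ran and with what value."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_epoch_end(self, x):
        self.log.append((self.name, x))


log = []
callbacks = Callbacks([Recorder('first', log), Recorder('second', log)])
callbacks.on_epoch_end(14)
print(log)  # [('first', 14), ('second', 14)]
```

The trade-off is explicitness: the hand-written version documents exactly which events exist, while this one forwards any on_* name and raises AttributeError if a callback lacks that hook. That cannot happen when every callback inherits from the Callback base class.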

Pseudo-implementation of the additional callbacks


class PrintCallback(Callback):
    def __init__(self):
        self.epoch = 0  # the callback keeps its own epoch counter

    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: End of Epoch. Epoch: {self.epoch}, Loss: {x}')
        self.epoch += 1


class ModelCheckPoint(Callback):
    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: Save Model')


class EarlyStoppingCallback(Callback):
    def on_epoch_end(self, x):
        if x < 3:  # x is the loss passed in by the training loop
            print(f'[{type(self).__name__}]: Early Stopped')


class LearningRateScheduler(Callback):
    def on_batch_end(self, x):
        print(f'[{type(self).__name__}]: Reduce learning rate')


def train_with_callbacks(x, callbacks=None):
    if callbacks is None:
        callbacks = Callbacks([])  # no callbacks registered
    n_epochs = 3
    n_batches = 6
    loss = 20

    for epoch in range(n_epochs):
        callbacks.on_epoch_start(loss)  # on_epoch_start
        for batch in range(n_batches):
            callbacks.on_batch_start(loss)  # on_batch_start
            loss = loss - 1  # Pretend we are training the model
            callbacks.on_batch_end(loss)  # on_batch_end
        callbacks.on_epoch_end(loss)  # on_epoch_end
    return loss

Here is the result.

>>> callbacks = Callbacks([PrintCallback(), ModelCheckPoint(),
...                        EarlyStoppingCallback(), LearningRateScheduler()])
>>> train_with_callbacks(x, callbacks=callbacks)
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Epoch: 0, Loss: 14
[ModelCheckPoint]: Save Model
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Epoch: 1, Loss: 8
[ModelCheckPoint]: Save Model
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Epoch: 2, Loss: 2
[ModelCheckPoint]: Save Model
[EarlyStoppingCallback]: Early Stopped
2

Hopefully, this convinces you that callbacks make the code cleaner and easier to maintain. If you just use plain if-else statements instead, you may end up with a big chunk of if-else clauses inside the training loop.
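The EarlyStoppingCallback above only prints a message; nothing actually stops the loop. One way to make it stop is sketched below, assuming a hypothetical stop_training flag that the callback sets and the training loop reads after each epoch (the flag name and train_with_early_stopping are illustrative, not from any library):

```python
class Callback:
    """Base class: every hook is a no-op."""
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass


class EarlyStopping(Callback):
    """Hypothetical: asks the loop to stop once the loss drops below a threshold."""
    def __init__(self, threshold=3):
        self.threshold = threshold
        self.stop_training = False  # hypothetical flag read by the loop

    def on_epoch_end(self, x):
        if x < self.threshold:
            self.stop_training = True


def train_with_early_stopping(callback, n_epochs=10, n_batches=6):
    loss = 20
    epochs_run = 0
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)
        epochs_run += 1
        # The loop decides to stop; the callback only sets the flag.
        if getattr(callback, 'stop_training', False):
            break
    return epochs_run, loss


print(train_with_early_stopping(EarlyStopping()))        # (3, 2)
print(train_with_early_stopping(Callback(), n_epochs=4))  # (4, -4)
```

With the same threshold of 3, training ends after the third epoch at a loss of 2, while the no-op Callback lets every epoch run. Using getattr with a default keeps callbacks that know nothing about stopping working unchanged.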

Reference

  1. https://stackoverflow.com/questions/824234/what-is-a-callback-function
