
Part 1: Zero-shot Detection
The versatility of CLIP is quite amazing. Below, I want to demonstrate its zero-shot capabilities on various tasks, such as text-prompted detection.
- Original Paper: OpenAI CLIP
- Official Code: GitHub
CLIP's Performance

Training Efficiency:
CLIP is among the most data-efficient models in its class, reaching 41% zero-shot accuracy after training on 400 million images, where a bag-of-words prediction baseline reaches 27% and a transformer language model baseline reaches 16% at the same number of images. In other words, CLIP needs far fewer training images than comparable approaches to reach the same accuracy.
Generalization: CLIP is trained on such a wide array of image styles that it is far more flexible than models trained only on ImageNet. It is important to note, though, that CLIP generalizes well to images similar to those it was trained on, not to images far outside its training distribution.
The pipeline: automatically generate proposal regions with selective search, compute each region's similarity to a natural-language query in CLIP embedding space, and return the top-k detections after non-maximum suppression. (The CLIP scoring step itself is sketched at the end of this section.)
%%capture
!pip install ftfy regex tqdm matplotlib selectivesearch
!pip install git+https://github.com/openai/CLIP.git
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import urllib.request
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import clip
from PIL import Image
from torchvision import transforms
import selectivesearch
from collections import OrderedDict
def load_image(img_path, resize=None, pil=False):
    image = Image.open(img_path).convert("RGB")
    if resize is not None:
        image = image.resize((resize, resize))
    if pil:
        return image
    image = np.asarray(image).astype(np.float32) / 255.
    return image
# Reference: https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
def nms(dets, thresh):
    """Greedy non-maximum suppression over (x1, y1, x2, y2, score) rows."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        # Keep the highest-scoring remaining box...
        i = order[0]
        keep.append(i)
        # ...then drop every remaining box whose IoU with it exceeds thresh.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
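As a quick sanity check of nms (a toy example of my own, not part of the original demo), two heavily overlapping boxes should collapse to the higher-scoring one, while a disjoint box survives:
boxes = np.array([
    [10, 10, 60, 60, 0.9],      # highest score, kept
    [12, 12, 62, 62, 0.8],      # IoU ~0.86 with the first box, suppressed at thresh=0.5
    [100, 100, 150, 150, 0.7],  # no overlap, kept
])
print(nms(boxes, thresh=0.5))   # indices [0, 2] are kept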
# Reference: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py
def vis_detections(im, dets, thresh=0.5, caption=None):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    top_idx = dets[:, -1].argmax()
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red' if i == top_idx else 'green',
                          linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:.3f}'.format(score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    if caption is not None:
        plt.title(caption, fontsize=20)
    plt.show()
Generate Bounding Boxes with Selective Search
image_url = 'http://archive.jsonline.com/Services/image.ashx?domain=www.jsonline.com&file=30025294_messykitchen1.jpg&resize=' #@param {type:"string"}
resize = None#@param {type:"raw"}
topk = 50#@param {type:"integer"}
scale = 200#@param {type:"integer"}
sigma = 0.8#@param {type:"number"}
min_size = 50#@param {type:"integer"}
# Download the image from the web.
image_path = 'image.png'
urllib.request.urlretrieve(image_url, image_path)
if resize is not None:
    assert isinstance(resize, int), "resize should be an integer."
img = load_image(image_path)
oh, ow = img.shape[:2]
print(f"Image resolution: {oh, ow}")
# Selective search on the (optionally resized) image.
img_search = load_image(image_path, resize=resize)
img_lbl, regions = selectivesearch.selective_search(
    img_search, scale=scale, sigma=sigma, min_size=min_size)
candidates = OrderedDict()
for r in regions:
    rect = r['rect']
    # Skip duplicate proposals (key by the rect itself, not the loop index).
    if rect in candidates:
        continue
    # Skip small regions.
    if r['size'] < 1000:
        continue
    x, y, w, h = rect
    # Skip regions with extreme aspect ratios.
    if w / h > 1.5 or h / w > 1.5:
        continue
    # If the search ran on a resized image, rescale the box to the original resolution.
    if resize is not None:
        sx = ow / resize
        sy = oh / resize
        x = np.clip(x * sx, 0, ow).astype(int)
        y = np.clip(y * sy, 0, oh).astype(int)
        w = np.clip(w * sx, 0, ow).astype(int)
        h = np.clip(h * sy, 0, oh).astype(int)
    candidates[rect] = (x, y, w, h)
candidates = list(candidates.values())
print(f"Generated {len(candidates)} bounding boxes. Taking the top {topk}.")
candidates = candidates[:topk]
# Display topk bounding boxes.
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 8))
ax.imshow(img)
for x, y, w, h in candidates:
    rect = mpatches.Rectangle(
        (x, y), w, h, fill=False, edgecolor='red', linewidth=1)
    ax.add_patch(rect)
plt.show()
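With candidate boxes in hand, the remaining step of the pipeline is to score each crop against the text query in CLIP embedding space and keep the strongest detections after non-maximum suppression. Below is a minimal sketch of that step; it assumes the variables defined above (candidates, img, image_path) and uses the official clip package, where clip.load returns the model together with its image preprocessing transform. The query string, NMS threshold, and display threshold are placeholders I chose for illustration.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

query = "a kitchen sink"  # placeholder natural-language query
pil_img = load_image(image_path, pil=True)

with torch.no_grad():
    # Embed the text query once.
    text_feat = F.normalize(
        model.encode_text(clip.tokenize([query]).to(device)), dim=-1)
    # Embed every candidate crop using CLIP's own preprocessing.
    crops = torch.stack([
        preprocess(pil_img.crop((x, y, x + w, y + h)))
        for x, y, w, h in candidates
    ]).to(device)
    img_feat = F.normalize(model.encode_image(crops), dim=-1)
    # Cosine similarity between each crop and the query.
    scores = (img_feat @ text_feat.T).squeeze(1).float().cpu().numpy()

# Assemble (x1, y1, x2, y2, score) rows and suppress overlapping boxes.
dets = np.array([[x, y, x + w, y + h, s]
                 for (x, y, w, h), s in zip(candidates, scores)])
keep = nms(dets, thresh=0.3)
# Show the surviving boxes; the mean score is a crude display threshold.
vis_detections(img, dets[keep], thresh=scores.mean(), caption=query)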
