Playing with OpenAI CLIP


Part 1: Zero-shot Detection

The versatility of CLIP is quite amazing. Below, I want to demonstrate its zero-shot capabilities on various tasks, starting with text-prompted detection.

CLIP's Performance


Training efficiency: CLIP is among the most data-efficient models in its class, reaching 41% accuracy after 400 million training images and outperforming alternatives such as bag-of-words prediction (27%) and a transformer language model (16%) at the same number of images. In other words, CLIP needs far fewer training images than comparable models to reach the same accuracy.

Generalization: CLIP was trained on such a wide array of image styles that it is far more flexible than models trained only on ImageNet. Note, however, that CLIP generalizes well to image styles represented in its training data; it can still struggle on images truly outside its training distribution.

The plan: automatically generate proposal regions with selective search, compute their similarity to a natural language query in CLIP embedding space, and return the top-k detections after non-maximum suppression. The helpers below set this up.

%%capture
!pip install ftfy regex tqdm matplotlib selectivesearch
!pip install git+https://github.com/openai/CLIP.git
 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import urllib.request
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import clip
from PIL import Image
from torchvision import transforms
import selectivesearch
from collections import OrderedDict
 
def load_image(img_path, resize=None, pil=False):
    """Load an RGB image, optionally resized to (resize, resize)."""
    image = Image.open(img_path).convert("RGB")
    if resize is not None:
        image = image.resize((resize, resize))
    if pil:
        return image
    # Float array in [0, 1] for selective search and matplotlib.
    image = np.asarray(image).astype(np.float32) / 255.
    return image
 
# Reference: https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
def nms(dets, thresh):
    """Greedy non-maximum suppression over (x1, y1, x2, y2, score) rows."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
 
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
 
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
 
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
 
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
 
    return keep
 
# Reference: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py
def vis_detections(im, dets, thresh=0.5, caption=None):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
 
    top_idx = dets[:, -1].argmax()
 
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
 
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red' if i == top_idx else 'green', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:.3f}'.format(score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    if caption is not None:
        plt.title(caption, fontsize=20)
    plt.show()
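
As a quick sanity check of nms, here is a toy example (the coordinates, scores, and IoU threshold are made up for illustration): the second box overlaps the first heavily and is suppressed, while the disjoint third box survives.

# Toy check: two heavily overlapping boxes plus one disjoint box.
dets_toy = np.array([
    [10, 10, 60, 60, 0.9],      # kept (highest score)
    [12, 12, 62, 62, 0.8],      # suppressed (IoU with the first box is ~0.86)
    [200, 200, 250, 250, 0.7],  # kept (no overlap with the first box)
], dtype=np.float32)
print(nms(dets_toy, thresh=0.5))  # -> [0, 2]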
 
 
image_url = 'http://archive.jsonline.com/Services/image.ashx?domain=www.jsonline.com&file=30025294_messykitchen1.jpg&resize=' #@param {type:"string"}
resize = None#@param {type:"raw"}
topk = 50#@param {type:"integer"}
scale =  200#@param {type:"integer"}
sigma =  0.8#@param {type:"number"}
min_size = 50#@param {type:"integer"}
 
# Download the image from the web.
image_path = 'image.png'
urllib.request.urlretrieve(image_url, image_path)
 
if resize is not None:
    assert isinstance(resize, int), "resize should be an integer."
 
img = load_image(image_path)
oh, ow = img.shape[:2]
print(f"Image resolution: {oh, ow}")
 
# Selective search.
img_search = load_image(image_path, resize=resize)
img_lbl, regions = selectivesearch.selective_search(
    img_search, scale=scale, sigma=sigma, min_size=min_size)
candidates = OrderedDict()
for r in regions:
    # Skip duplicates, tiny regions, and extremely elongated boxes.
    if r['rect'] in candidates:
        continue
    if r['size'] < 1000:
        continue
    x, y, w, h = r['rect']
    if w / h > 1.5 or h / w > 1.5:
        continue
    if resize is not None:
        # Map box coordinates back to the original resolution.
        sx = ow / resize
        sy = oh / resize
        x = int(np.clip(x * sx, 0, ow))
        y = int(np.clip(y * sy, 0, oh))
        w = int(np.clip(w * sx, 0, ow))
        h = int(np.clip(h * sy, 0, oh))
        r['rect'] = (x, y, w, h)
    candidates[r['rect']] = r['rect']
candidates = list(candidates.values())
print(f"Generated {len(candidates)} bounding boxes. Keeping the first {topk}.")
candidates = candidates[:topk]
 
# Display topk bounding boxes.
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 8))
ax.imshow(img)
for x, y, w, h in candidates:
    rect = mpatches.Rectangle(
        (x, y), w, h, fill=False, edgecolor='red', linewidth=1)
    ax.add_patch(rect)
plt.show()
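
With proposals in hand, the remaining step is the CLIP scoring described above: crop each candidate box, embed the crops and the text query with CLIP, and rank boxes by cosine similarity before NMS. Below is a minimal sketch; the query string, the NMS IoU threshold, and the score threshold passed to vis_detections are illustrative placeholders, not tuned values.

query = "a kitchen sink" #@param {type:"string"}

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Crop every proposal from the original-resolution PIL image.
pil_img = load_image(image_path, pil=True)
crops, boxes = [], []
for x, y, w, h in candidates:
    crops.append(preprocess(pil_img.crop((x, y, x + w, y + h))))
    boxes.append((x, y, x + w, y + h))

with torch.no_grad():
    image_features = model.encode_image(torch.stack(crops).to(device))
    text_features = model.encode_text(clip.tokenize([query]).to(device))
    # Cosine similarity between each crop and the query.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    scores = (image_features @ text_features.T).squeeze(1).cpu().numpy()

# Assemble (x1, y1, x2, y2, score) rows, suppress overlaps, and plot.
dets = np.array([(*b, s) for b, s in zip(boxes, scores)], dtype=np.float32)
keep = nms(dets, thresh=0.3)  # IoU threshold chosen arbitrarily
# CLIP cosine similarities are small (roughly 0.2-0.35), so threshold
# relative to the scores rather than using vis_detections' default of 0.5.
vis_detections(img, dets[keep], thresh=float(np.percentile(scores, 90)),
               caption=query)

In the resulting plot, the red box marks the highest-scoring detection; every other box that clears the threshold is drawn in green.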
 
 