Skip to content

Visualize inference results

This article explains how to visualize inference results for the following project types:

  • Image classification
  • Multilabel image classification
  • Object detection
  • Instance and semantic segmentation

To set up the basic visualization of inference results, you will need two scripts:

  • api_inference.py, which interacts with an API to retrieve inference results.
  • utils.py, which contains utility functions to process and visualize the returned data.

Retrieve inference results

The api_inference.py example script reads an input image, performs platform inference over the API, and visualizes the results. An image stored on a local device is provided as the input. The output is a processed image with visualized inference results.

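The input parameters are configured in the CONFIG_* dictionaries, one per project type (CONFIG_SL_CLF, CONFIG_ML_CLF, CONFIG_OD, and CONFIG_SE). By default, the script reads CONFIG_SL_CLF; to use another project type, change the script to read the matching dictionary. Each dictionary contains the following keys: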

  • img_path: Path to the input image for inference.
  • api_endpoint: The Robovision AI inference API endpoint of your deployment. The script sends the image to <api_endpoint>/predict.
api_inference.py
import requests 
import cv2
import json
import numpy as np
import os
from PIL import Image

from utils import ColorGenerator, create_output_image  # Import utility functions for color generation and image processing

# Connection settings, one dictionary per project type. Replace the placeholders
# before running the script. Keys:
#   img_path     - path to the input image for inference
#   api_endpoint - deployment API endpoint; the script posts to "<api_endpoint>/predict"

# Single label classification
CONFIG_SL_CLF = dict(
    img_path="<path/to/image>",
    api_endpoint="<deployment_api_endpoint>",
)

# Multi label classification
CONFIG_ML_CLF = dict(
    img_path="<path/to/image>",
    api_endpoint="<deployment_api_endpoint>",
)

# Object detection
CONFIG_OD = dict(
    img_path="<path/to/image>",
    api_endpoint="<deployment_api_endpoint>",
)

# Segmentation
CONFIG_SE = dict(
    img_path="<path/to/image>",
    api_endpoint="<deployment_api_endpoint>",
)

if __name__ == '__main__':
    import sys  # Only needed here, to exit with a non-zero status on failure

    # Select the configuration that matches your project type:
    # CONFIG_SL_CLF, CONFIG_ML_CLF, CONFIG_OD or CONFIG_SE.
    config = CONFIG_SL_CLF

    # Initialize session parameters
    color_generator = ColorGenerator()  # Assigns a display color to each class label

    # 1. Read input image
    img_path = config["img_path"]

    # Ensure the provided path points to an existing file
    while not os.path.isfile(img_path):
        img_path = input("The provided path to the image does not exist. Please provide a valid path: ")

    input_image = cv2.imread(img_path)  # Load the image using OpenCV
    if input_image is None:
        # cv2.imread returns None instead of raising when the file is not a readable image.
        print(f"Could not read an image from: {img_path}")
        sys.exit(1)

    # Uncomment to display the input image before processing
    # print(f"Displaying image: {img_path}")
    # cv2.imshow("Input Image", input_image)

    # 2. Perform inference by sending the image to the API.
    # Both the HTTP session and the file are closed when the `with` block exits.
    with requests.Session() as session, open(img_path, 'rb') as f:
        response = session.post(f"{config['api_endpoint']}/predict", files={'image': f.read()})

    # Check if the request was successful
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}. Reason: {response.reason}")
        sys.exit(1)  # Non-zero status signals the failure to the calling shell

    json_response = response.json()  # Parse the response JSON

    # 3. Visualize inference results
    resulting_image = create_output_image(input_image, json_response, color_generator)

    # Display the processed image with inference results
    cv2.imshow("Inference Result", resulting_image)
    cv2.waitKey(0)  # Wait for a key press before closing the window
    cv2.destroyAllWindows()  # Close the image display window

Visualize inference results

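The visualization functionality is centralized in the utils.py file. To generate an output image, use the create_output_image(...) function as the primary entry point. It determines the project type from the structure of the JSON response and calls the matching drawing function. The create_output_centroid(...) function works the same way, but additionally marks the center of each detected object or mask.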

utils.py
import base64
import cv2
import numpy as np
import random
from config_dashboard import DashboardConfig
import requests
import json
from pycocotools.mask import decode
from json.decoder import JSONDecodeError


class ColorGenerator:
    """
    Assigns one display color per class name and remembers it.

    Classes receive the colors from ``DashboardConfig.class_colors`` in the order
    they are first seen (the config palette is chosen for clear contrast). Once
    that palette is used up, every further class receives a random color.
    """
    def __init__(self):
        # Class name -> color tuple, in first-seen order.
        self.class_color_dict: dict = {}

    def set_color(self, class_name: str):
        """Choose a color for ``class_name``, store it, and return it."""
        palette = DashboardConfig.class_colors
        n_assigned = len(self.class_color_dict)
        if n_assigned >= len(palette):
            # Palette exhausted: fall back to a random color.
            new_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        else:
            new_color = palette[n_assigned]
        self.class_color_dict[class_name] = new_color
        return new_color

    def get_color(self, class_name: str):
        """Return the color of ``class_name``, assigning a new one on first use."""
        if class_name in self.class_color_dict:
            return self.class_color_dict[class_name]
        return self.set_color(class_name)


"""
Helper functions 
"""
def decode_image(image: str) -> np.ndarray:
    """
    Decode a base64-encoded image string into a numpy array in RGB channel order.

    :param image: Base64 string holding an encoded image file (any format cv2.imdecode reads).
    :return: Array of shape (H, W, 3) with the channels reversed from OpenCV's BGR
        order to RGB. The array is C-contiguous, so it can be passed directly to
        cv2 drawing functions.
    :raises ValueError: If the decoded bytes are not a readable image.
    """
    bytes_image = base64.b64decode(image)
    decoded_image = cv2.imdecode(np.frombuffer(bytes_image, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded_image is None:
        # cv2.imdecode signals failure by returning None rather than raising.
        raise ValueError("Could not decode the provided data as an image")

    # Reversing the channel axis yields a strided view; make it contiguous.
    return np.ascontiguousarray(decoded_image[:, :, ::-1])


def _dispatch_by_project_type(image: np.ndarray, json_response, color_generator: "ColorGenerator",
                              renderers: dict) -> np.ndarray:
    """
    Render ``json_response`` onto ``image`` with the renderer registered for its project type.

    :param renderers: Maps a project type code returned by ``get_project_type``
        ("SLCL", "MLCL", "OD", "SE") to a function ``(image, json_response, color_generator)``.
    :raises SystemExit: With status 1, after printing a message, if the project
        type has no registered renderer.
    """
    project_type = get_project_type(json_response)
    renderer = renderers.get(project_type)
    if renderer is None:
        print("Project type not supported. Exiting application")
        # Non-zero status so the failure is visible to the calling shell.
        raise SystemExit(1)
    return renderer(image, json_response, color_generator)


def create_output_image(image: np.ndarray, json_response: dict, color_generator: "ColorGenerator") -> np.ndarray:
    """
    General entry point: return ``image`` with the inference results of ``json_response`` drawn on it.

    The project type is derived from the structure of ``json_response``, and the matching
    renderer draws labels (single/multi label classification), bounding boxes (object
    detection) or mask overlays (segmentation).

    :param image: Input image.
    :param json_response: Parsed JSON returned by the inference API (a dict or a list of predictions).
    :param color_generator: Provides one consistent color per class label.
    :return: The annotated image. For an empty response, ``image`` itself is returned unchanged.
    :raises SystemExit: If the project type is not supported.
    """
    if not json_response:
        # Nothing to draw. This must be checked before the project type is
        # determined, because that step inspects the first prediction.
        return image

    return _dispatch_by_project_type(image, json_response, color_generator, {
        "SLCL": get_SL_CLF_output,   # Single label classification
        "MLCL": get_ML_CLF_output,   # Multi label classification
        "OD": get_OD_output,         # Object detection
        "SE": get_SE_output,         # Segmentation
    })


def create_output_centroid(image: np.ndarray, json_response: dict, color_generator: "ColorGenerator") -> np.ndarray:
    """
    Like ``create_output_image``, but also marks centers for object detection and segmentation.

    Object detection results mark each bounding-box center with a cross and its
    coordinates. Segmentation results mark each mask's centroid with a cross.
    Classification results are rendered as in ``create_output_image``.

    :param image: Input image.
    :param json_response: Parsed JSON returned by the inference API (a dict or a list of predictions).
    :param color_generator: Provides one consistent color per class label.
    :return: The annotated image. For an empty response, ``image`` itself is returned unchanged.
    :raises SystemExit: If the project type is not supported.
    """
    if not json_response:
        # Nothing to draw. This must be checked before the project type is
        # determined, because that step inspects the first prediction.
        return image

    return _dispatch_by_project_type(image, json_response, color_generator, {
        "SLCL": get_SL_CLF_output,       # Single label classification
        "MLCL": get_ML_CLF_output,       # Multi label classification
        "OD": get_OD_output_centroid,    # Object detection with box centers
        "SE": get_SE_output_centroid,    # Segmentation with mask centroids
    })


def get_SL_CLF_output(image: np.ndarray, prediction: list, color_generator: ColorGenerator) -> np.ndarray:
    """
    Single label classification: draw "<label>: <confidence>" in a filled box
    near the top-left corner of a copy of ``image``, and return that copy.

    Although annotated as ``list``, ``prediction`` is indexed like a dict:
    ``prediction["value"]["id"]`` is the class label and ``prediction["confidence"]`` its score.
    """
    canvas = image.copy()
    text_x, text_y = 10, 20  # Baseline-left corner of the text
    class_id = prediction["value"]["id"]
    score = prediction["confidence"]
    base_color = color_generator.get_color(class_id)

    font_kwargs = {
        'text': f"{class_id}: {score}",
        'fontFace': cv2.FONT_HERSHEY_DUPLEX,
        'thickness': 1,
        'fontScale': 0.8,
    }
    (text_w, text_h), _baseline = cv2.getTextSize(**font_kwargs)

    # Filled background box with 10 px of padding above the text and below the
    # baseline, made 20 px wider than the text.
    # NOTE(review): channels are reversed here but not in the OD/SE renderers;
    # confirm which channel order the DashboardConfig colors use.
    box_color = (base_color[2], base_color[1], base_color[0])
    pad = 10
    canvas = cv2.rectangle(
        canvas,
        (text_x, text_y - text_h - pad),
        (text_x + text_w + 20, text_y + pad),
        box_color,
        -1,
    )

    # White text on top of the box.
    return cv2.putText(img=canvas, org=(text_x, text_y), color=(255, 255, 255), **font_kwargs)


def get_ML_CLF_output(image: np.ndarray, predictions: list, color_generator: ColorGenerator) -> np.ndarray:
    """
    Multi label classification: Returns a copy of the image with one filled label box
    ("<label>: <confidence>") per prediction, stacked downwards from the top-left corner.

    Each prediction is read as ``prediction["value"]["id"]`` (class label) and
    ``prediction["confidence"]`` (score).
    """
    image_copy = image.copy()
    x, y = 10, 20   # Baseline-left position of the first label
    vpad_rect = 10  # Box padding above the text and below the baseline
    row_gap = 2     # Pixels between consecutive label boxes

    # Multi label classifications contain multiple labels & confidence values
    for prediction in predictions:
        label = prediction["value"]["id"]
        confidence = prediction["confidence"]
        label_color = color_generator.get_color(label)

        text_parameters = {
            'text': f"{label}: {confidence}",
            'fontFace': cv2.FONT_HERSHEY_DUPLEX,
            'thickness': 1,
            'fontScale': 0.8,
        }

        # Draw rectangle behind text.
        # NOTE(review): channels are reversed here but not in the OD/SE renderers;
        # confirm which channel order the DashboardConfig colors use.
        text_dims, baseline = cv2.getTextSize(**text_parameters)
        color = (label_color[2], label_color[1], label_color[0])
        width_rect = text_dims[0] + 20
        image_copy = cv2.rectangle(image_copy, (x, y - text_dims[1] - vpad_rect), (x + width_rect, y + vpad_rect), color, -1)

        # Draw text
        image_copy = cv2.putText(
            img=image_copy,
            **text_parameters,
            org=(x, y),
            color=(255, 255, 255),
        )

        # Advance by the full box height (text plus padding above and below) and a gap.
        # Any smaller step lets the next filled box cover the bottom of this label.
        y += text_dims[1] + 2 * vpad_rect + row_gap

    return image_copy


def get_OD_output(image: np.ndarray, predictions: list, color_generator: ColorGenerator) -> np.ndarray:
    """
    Object detection: Returns a copy of the image with bounding boxes, labels & confidences.

    Each prediction is read as:
      - ``prediction["value"]["center"]["position"]``: box center (x, y) in pixels
      - ``prediction["value"]["size"]``: box extents, unpacked as (h, w)
      - ``prediction["label"]["id"]``: class label
      - ``prediction["confidence"]``: score, printed with two decimals
    """
    image_copy = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1

    for prediction in predictions:
        # Iterate over bounding boxes
        x, y = [int(i) for i in prediction["value"]["center"]["position"]]
        # NOTE(review): "size" is unpacked as (h, w); confirm the API does not send (w, h).
        h, w = [int(i) for i in prediction["value"]["size"]]
        label = prediction["label"]["id"]
        confidence = prediction["confidence"]
        label_color = color_generator.get_color(label)

        # Draw bounding box
        half_w, half_h = int(w / 2), int(h / 2)
        image_copy = cv2.rectangle(image_copy, (x - half_w, y - half_h), (x + half_w, y + half_h), label_color, 2)

        # Add label & confidence above the bounding box, horizontally centered.
        # Measure with the same thickness that is used for drawing, so the centering is exact.
        label_text = f"{label}: {confidence:.2f}"
        label_size, _ = cv2.getTextSize(label_text, font, fontScale=font_scale, thickness=text_thickness)
        # 10 px above the box, but never above the top edge of the image.
        label_y = max(y - half_h - 10, label_size[1])
        label_position = (x - int(label_size[0] / 2), label_y)
        image_copy = cv2.putText(image_copy, label_text, label_position, font, font_scale, label_color, text_thickness, cv2.LINE_AA)

    return image_copy

def get_OD_output_centroid(image: np.ndarray, predictions: list, color_generator: ColorGenerator) -> np.ndarray:
    """
    Object detection: Returns a copy of the image with bounding boxes, labels & confidences,
    and marks the center of each bounding box with a cross and its "(x, y)" coordinates.

    Each prediction is read as:
      - ``prediction["value"]["center"]["position"]``: box center (x, y) in pixels
      - ``prediction["value"]["size"]``: box extents, unpacked as (h, w)
      - ``prediction["label"]["id"]``: class label
      - ``prediction["confidence"]``: score, printed with two decimals
    """
    image_copy = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1
    cross_size = 5  # Length of each arm of the center cross, in pixels

    for prediction in predictions:
        # Extract bounding box and label info
        x, y = [int(i) for i in prediction["value"]["center"]["position"]]
        # NOTE(review): "size" is unpacked as (h, w); confirm the API does not send (w, h).
        h, w = [int(i) for i in prediction["value"]["size"]]
        label = prediction["label"]["id"]
        confidence = prediction["confidence"]
        label_color = color_generator.get_color(label)

        # Draw bounding box
        half_w, half_h = int(w / 2), int(h / 2)
        image_copy = cv2.rectangle(image_copy, (x - half_w, y - half_h), (x + half_w, y + half_h), label_color, 2)

        # Add label & confidence above the bounding box, horizontally centered,
        # 10 px above the box but never above the top edge of the image.
        label_text = f"{label}: {confidence:.2f}"
        label_size, _ = cv2.getTextSize(label_text, font, fontScale=font_scale, thickness=text_thickness)
        label_y = max(y - half_h - 10, label_size[1])
        label_position = (x - int(label_size[0] / 2), label_y)
        image_copy = cv2.putText(image_copy, label_text, label_position, font, font_scale, label_color, text_thickness, cv2.LINE_AA)

        # Draw cross at the center of the bounding box
        image_copy = cv2.line(image_copy, (x - cross_size, y), (x + cross_size, y), label_color, 1)
        image_copy = cv2.line(image_copy, (x, y - cross_size), (x, y + cross_size), label_color, 1)

        # Add center coordinates, offset right/down so they do not overlap the cross
        coord_text = f"({x}, {y})"
        coord_position = (x + 10, y + 5)
        image_copy = cv2.putText(image_copy, coord_text, coord_position, font, font_scale, label_color, text_thickness, cv2.LINE_AA)

    return image_copy


def get_SE_output(image: np.ndarray, predictions: list, color_generator: "ColorGenerator",
                  *, alpha: float = 0.1, draw_contours: bool = True) -> np.ndarray:
    """
    Segmentation: Returns a copy of the image with the predicted masks overlaid.

    Each prediction's ``"value"`` is an RLE-encoded mask, decoded with pycocotools,
    and ``prediction["label"]["id"]`` is its class label. Masked pixels are blended
    with the class color. By default, each mask's contours are also outlined.

    :param alpha: Weight of the class color in the blend (0 = invisible, 1 = opaque).
    :param draw_contours: Whether to outline every mask in its class color.
    :return: The annotated copy, or an unmodified copy when ``predictions`` is empty.
    """
    image_copy = image.copy()
    if not predictions:
        # No masks: nothing to overlay, and no first mask to take the shape from.
        return image_copy

    lst_masks = []
    for prediction in predictions:
        mask_label = prediction["label"]["id"]
        # Platform inference: mask is RLE encoded
        mask_decoded = decode(prediction["value"])
        lst_masks.append({'mask': mask_decoded, 'name': mask_label})

    # Colored mask layer, plus a boolean map of every pixel covered by any mask.
    # NOTE(review): assumes every mask has the same height/width as ``image``.
    mask_shape = lst_masks[0]['mask'].shape
    color_layer = np.zeros((mask_shape[0], mask_shape[1], 3), dtype='uint8')
    covered = np.zeros((mask_shape[0], mask_shape[1]), dtype=bool)

    for mask_dict in lst_masks:
        mask = mask_dict['mask']
        mask_color = color_generator.get_color(mask_dict['name'])
        in_mask = mask > 0

        # Later masks overwrite earlier ones where they overlap.
        color_layer[in_mask] = mask_color
        covered |= in_mask

        if draw_contours:
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_copy, contours, -1, mask_color, 3)

    # Blend only the covered pixels. Tracking coverage explicitly, instead of testing
    # the color layer for non-zero channels, also blends classes whose color has a 0 channel.
    image_copy[covered] = (image_copy[covered].astype('float') * (1 - alpha) +
                           color_layer[covered].astype('float') * alpha).astype('uint8')

    return image_copy

def get_SE_output_centroid(image: np.ndarray, predictions: list, color_generator: "ColorGenerator",
                           *, alpha: float = 0.1, draw_contours: bool = True) -> np.ndarray:
    """
    Segmentation: Returns a copy of the image with the predicted masks overlaid
    and each mask's centroid marked with a cross.

    Each prediction's ``"value"`` is an RLE-encoded mask, decoded with pycocotools,
    and ``prediction["label"]["id"]`` is its class label. Masked pixels are blended
    with the class color. By default, each mask's contours are also outlined.

    :param alpha: Weight of the class color in the blend (0 = invisible, 1 = opaque).
    :param draw_contours: Whether to outline every mask in its class color.
    :return: The annotated copy, or an unmodified copy when ``predictions`` is empty.
    """
    image_copy = image.copy()
    if not predictions:
        # No masks: nothing to overlay, and no first mask to take the shape from.
        return image_copy

    lst_masks = []
    for prediction in predictions:
        mask_label = prediction["label"]["id"]
        # Platform inference: mask is RLE encoded
        mask_decoded = decode(prediction["value"])
        lst_masks.append({'mask': mask_decoded, 'name': mask_label})

    # Colored mask layer, plus a boolean map of every pixel covered by any mask.
    # NOTE(review): assumes every mask has the same height/width as ``image``.
    mask_shape = lst_masks[0]['mask'].shape
    color_layer = np.zeros((mask_shape[0], mask_shape[1], 3), dtype='uint8')
    covered = np.zeros((mask_shape[0], mask_shape[1]), dtype=bool)

    cross_size = 10  # Length of each arm of the centroid cross, in pixels
    cross_thickness = 2

    for mask_dict in lst_masks:
        mask = mask_dict['mask']
        mask_color = color_generator.get_color(mask_dict['name'])
        in_mask = mask > 0

        # Later masks overwrite earlier ones where they overlap.
        color_layer[in_mask] = mask_color
        covered |= in_mask

        if draw_contours:
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_copy, contours, -1, mask_color, 3)

        # Centroid from image moments: (m10 / m00, m01 / m00). m00 is the sum of the
        # mask values (the area for a 0/1 mask), so an empty mask is skipped.
        M = cv2.moments(mask)
        if M["m00"] != 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
            cv2.line(image_copy, (cX - cross_size, cY), (cX + cross_size, cY), mask_color, cross_thickness)
            cv2.line(image_copy, (cX, cY - cross_size), (cX, cY + cross_size), mask_color, cross_thickness)

    # Blend only the covered pixels. Tracking coverage explicitly, instead of testing
    # the color layer for non-zero channels, also blends classes whose color has a 0 channel.
    image_copy[covered] = (image_copy[covered].astype('float') * (1 - alpha) +
                           color_layer[covered].astype('float') * alpha).astype('uint8')

    return image_copy


def draw_class_label(image: np.ndarray, class_label: str, class_score: float = None):
    """
    Classification: Return a resized copy of ``image`` with the class label drawn on it.

    The copy is scaled to a height of 500 px, keeping the aspect ratio. The caption is
    ``class_label``, or ``"<class_label>: <class_score>"`` with the score to two decimals
    when a score is given. It is written in white on a filled box in the top-left corner.
    """
    target_h = 500
    src_h, src_w = image.shape[:2]
    scale = target_h / src_h
    canvas = cv2.resize(image.copy(), (int(scale * src_w), target_h))

    if class_score is None:
        caption = class_label
    else:
        caption = f"{class_label}: {class_score:.2f}"

    font_kwargs = dict(
        text=caption,
        fontFace=cv2.FONT_HERSHEY_DUPLEX,
        thickness=1,
        fontScale=0.8,
    )
    (text_w, text_h), _baseline = cv2.getTextSize(**font_kwargs)

    # Filled box from the corner, as wide as the text and 1.5x its height.
    canvas = cv2.rectangle(canvas, (0, 0), (text_w, int(1.5 * text_h)), (156, 32, 240), -1)

    return cv2.putText(img=canvas, org=(0, text_h + 3), color=(255, 255, 255), **font_kwargs)


def get_project_type(response) -> str:
    """
    Determine the project type from the structure of an inference response.

    :param response: Parsed JSON response. A dict is treated as a single prediction;
        for a list, only the first prediction is inspected.
    :return: One of the following codes, or "UNK" when the response is empty,
        malformed, or matches none of them:
        - "SLCL": single label classification (a dict whose "value" has only an "id")
        - "MLCL": multi label classification (a list whose first "value" has only an "id")
        - "OD": object detection ("value" has exactly "center" and "size")
        - "SE": segmentation ("value" has exactly "size" and "counts")
    """
    is_list = isinstance(response, list)
    if is_list:
        if not response:
            return "UNK"
        first = response[0]
    else:
        first = response

    value = first.get("value") if isinstance(first, dict) else None
    if not isinstance(value, dict):
        return "UNK"

    # Compare key *sets*: JSON object key order is not significant, so
    # {"counts": ..., "size": ...} must be recognized like {"size": ..., "counts": ...}.
    keys = set(value)
    if keys == {"id"}:
        # Classification
        return "MLCL" if is_list else "SLCL"
    if keys == {"center", "size"}:
        # Object detection
        return "OD"
    if keys == {"size", "counts"}:
        # Segmentation
        return "SE"
    return "UNK"