Skip to content
KC_notebook
  • AI Chat
  • Code
  • Report
  • import torch 
    
    
    # Example predicted / ground-truth label maps. torch.Tensor stores them as
    # float32; equality tests against small integer class ids are still exact.
    y_pred = torch.Tensor([
        [1, 1, 2, 2, 2],
        [1, 1, 2, 1, 2],
        [1, 0, 0, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 1, 1, 1, 2]
    ])

    y_true = torch.Tensor([
        [1, 1, 1, 2, 2],
        [1, 1, 1, 2, 2],
        [1, 1, 1, 2, 2],
        [0, 0, 0, 2, 2],
        [0, 0, 0, 2, 2]
    ])

    NUM_CLASSES = 3


    def jaccard_per_class(y_pred, y_true, num_classes):
        """Return the per-class Jaccard index (IoU) of two label tensors.

        For each class ``c`` in ``range(num_classes)`` the score is
        ``|(y_pred == c) & (y_true == c)| / |(y_pred == c) | (y_true == c)|``,
        counted over every element, so inputs of any matching shape work
        (no flattening required).

        Args:
            y_pred: Predicted class labels.
            y_true: Ground-truth class labels; must have the same shape as
                ``y_pred``.
            num_classes: Number of classes; scores are computed for ids
                ``0 .. num_classes - 1``.

        Returns:
            A 1-D float tensor of length ``num_classes``. A class that occurs in
            neither tensor has an empty union and gets ``nan`` (0 / 0) instead
            of an arbitrary score.

        Raises:
            ValueError: If the shapes differ or ``num_classes < 1``.
        """
        if y_pred.shape != y_true.shape:
            # Mismatched shapes could otherwise broadcast and silently compare
            # the wrong elements.
            raise ValueError(
                f"shape mismatch: y_pred {tuple(y_pred.shape)} "
                f"vs y_true {tuple(y_true.shape)}"
            )
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        scores = []
        for c in range(num_classes):
            pred_c = y_pred == c
            true_c = y_true == c
            intersection = (pred_c & true_c).sum()
            union = (pred_c | true_c).sum()
            # True division of the integer counts promotes to float; an empty
            # union therefore yields nan rather than raising.
            scores.append(intersection / union)
        return torch.stack(scores)


    def mean_jaccard(y_pred, y_true, num_classes):
        """Return the mean Jaccard index (mIoU) over ``num_classes`` classes.

        Classes absent from both tensors (``nan`` per-class score) are left out
        of the average via ``torch.nanmean``; the result is ``nan`` only when
        every class is absent.
        """
        return torch.nanmean(jaccard_per_class(y_pred, y_true, num_classes))


    # Flattened copies are kept so later cells see the same y_pred / y_true as
    # before; the helpers above do not need flat input.
    y_pred = y_pred.flatten()
    y_true = y_true.flatten()

    per_class_iou = jaccard_per_class(y_pred, y_true, NUM_CLASSES)
    # Individual scores kept under their original names for later cells.
    jaccard_0, jaccard_1, jaccard_2 = per_class_iou
    mean_iou = torch.nanmean(per_class_iou)
    mean_iou  # displayed as the notebook cell's output