微智科技网
您的当前位置:首页【pytorch】StatScores的原理与使用

【pytorch】StatScores的原理与使用

来源:微智科技网

Confusion matrix (混淆矩阵)

在介绍StatScores之前,我们先复习以下Confusion matrix。

我们有两组数据,分别为真实分布预测分布
预测为真定义为Possitive,预测为假定义为Negetive

四分类定义

关系图

额外提一下Precision与Recall

Precision(准确率) 与 Recall (召回率)

P r e c i s i o n = T P T P + F P Precision = \cfrac {TP} {TP+FP} Precision=TP+FPTP
R e c a l l = T P T P + F N Recall = \cfrac {TP} {TP+FN} Recall=TP+FNTP


StatScores类

继承关系

直接继承与

class StatScores(Metric)

四类任务

它将处理的case分为了四类

  1. Binary 二分类
  2. MultiClass 多分类
  3. MultiLabel 多标签
  4. MultiClass&MultiLabel

没有入参指定所属的任务case,代码中是根据pred张量来判断的。逻辑如下,

因为笔者暂时只使用第1和2中,所以其他暂不介绍了。

Update与Compute方法

所有继承Metrics的子类都需要实现Update和Compute方法。

1. update

该方法主要作用是将preds和target做one hot化,所属分类任务的case也在该方法中识别的。

_input_format_classification的四个参数

这里有三个参数注意以下:

  • threshold
    它仅仅作用与Binary的任务,作用是preds张量中,如果元素大于threshold,则规整为1,否则规整为0
  • num_classes
    指明分类种类,如果不指明的话,代码中根据元素值的最大值来判断。这个值同时也会影响one_hot后的数据长度。
  • multiclass
    如果multiclass=False,则强制认为所属任务为Binary。True或者不设置(None)则根据入参自行判断
  • topk
    在多分类任务中,在做one_hot转换时,需要返回的最大前k个位置。
    比如[0.1,0.5,0.4], 在topk=1(默认时),返回的是 [0,1,0],
    如果topk=2,则返回的是[0,1,1]

_stat_scores

_stat_scores是真实计算tp, fp, tn, fn四个值的地方。

举个例子

假设我们有如下

preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

首先,在 _input_format_classification方法处理后,这两个张量会转换为one_hot形式如下,

preds = [[1,0], [0,1], [1,0]]
target= [[0,1], [0,1], [1,0]]

然后, 进入**_stat_scores**
第,65行的计算结果如下:

# 预测true是正确的预测值和预测是false是正确的预测值
true_pred, false_pred = [[False,False], [True, True], [True, True]] , 
                           [ [True True], [False, False] [False, False]
# 预测是Ture的预测值与预测是False的预测值
pos_pred, neg_pred = [[False, True] [False, True] [True, False]] , 
                           [[True False] [True False] [True False]]

这两者再两两相乘,得到tp fp tn fn

    tp = (true_pred * pos_pred).sum(dim=dim)
    fp = (false_pred * pos_pred).sum(dim=dim)

    tn = (true_pred * neg_pred).sum(dim=dim)
    fn = (false_pred * neg_pred).sum(dim=dim)

2. compute

compute调用内部方法 _stat_scores_compute

_stat_scores_compute

该方法返回一个数组, [tp, fp, tn, fn, tp_fn]

因篇幅问题不能全部显示,请点此查看更多更全内容