龙空技术网

分类之K最近邻居(KNN)算法

Java个人学习心得 152

前言:

现在看官们对“javaknn”大约比较珍视,姐妹们都想要了解一些“javaknn”的相关内容。那么小编也在网上搜集了一些有关“javaknn””的相关内容,希望小伙伴们能喜欢,朋友们一起来了解一下吧!

一些基础概念参看《聚类之K均值(K-means)算法》。

至于为什么写KNN算法,单纯就是它跟K-means算法都是K字辈儿的,,,,

聚类是无监督学习,将相似(距离近)的数据“聚”为一“簇”。

分类是监督学习,根据已知的数据集和对应的分类,判定新数据应“分”为哪一“类”。

以KNN来说,就是计算出新数据A离自己最近的K个数据,这K个数据中,哪个类别的最多,就判定A属于哪个类别。

这个k就很重要了,如下图:

如果我们设绿点为新数据,k=x,则篮圈中蓝点最多,绿点归于蓝。

如果k=y,则红圈中红点最多,绿点归于红。

直接上代码:

public class KNN {    public static void main(String[] args) {        //交叉验证        //KNN.crossValidation();        //新节点        // Node newNode = new Node((int) (Math.random() * 100), (int) (Math.random() * 100));        Node newNode = new Node(52, 57);        System.out.println("新节点:[" + newNode.getX() + "," + newNode.getY() + ",3],");        //已分类的节点集        List<Node> nodes = KNN.buildNodes();        //临近的节点数量        int k = 3;        Integer type = KNN.knn(newNode, nodes, k);        System.out.println("临近" + k + "个节点类型数量最多的类型:" + type);    }    public static Integer knn(Node newNode, List<Node> nodes, int k) {        //计算新节点到每个节点的距离        for (int i = 0; i < nodes.size(); i++) {            Node node = nodes.get(i);            int d = Distance.euclidean(newNode, node).intValue();            node.setG(d);        }        //节点距离的优先队列        PriorityQueue<Node> pq = new PriorityQueue<Node>(new Comparator<Node>() {            public int compare(Node o1, Node o2) {                if (o1.getG() > o2.getG()) {                    return 1;                } else if (o1.getG() < o2.getG()) {                    return -1;                } else {                    return 0;                }            }        });        pq.addAll(nodes);        //类型-数量Map        Map<Integer, Integer> typeMap = Maps.newHashMap();        typeMap.put(0, 0);        typeMap.put(1, 0);        typeMap.put(2, 0);        //取出距离最近的K个节点,并计算每个类型的数量        for (int i = 0; i < k; i++) {            Node oldNode = pq.poll();            //System.out.println("[" + oldNode.getX() + "," + oldNode.getY() + "," + oldNode.getCode() + "],");            typeMap.put(oldNode.getCode(), typeMap.get(oldNode.getCode()) + 1);        }        //System.out.println("临近" + k + "个节点类型数量:" + typeMap);        //取出数量最多的类型        PriorityQueue<Map.Entry<Integer, Integer>> pq2 = new PriorityQueue<Map.Entry<Integer, Integer>>(                new Comparator<Map.Entry<Integer, Integer>>() {                    public int compare(Entry<Integer, Integer> o1, Entry<Integer, Integer> o2) {                        return o2.getValue().compareTo(o1.getValue());                    }                });        pq2.addAll(typeMap.entrySet());        Integer type = pq2.poll().getKey();        return type;    }    public static List<Node> buildNodes() {        ArrayList<Node> nodes = Lists.newArrayListWithCapacity(100);        nodes.add(new Node(95, 86, 0));        nodes.add(new Node(62, 96, 0));        nodes.add(new Node(38, 78, 0));        nodes.add(new Node(81, 83, 0));        nodes.add(new Node(58, 86, 0));        nodes.add(new Node(24, 88, 0));        nodes.add(new Node(44, 65, 0));        nodes.add(new Node(34, 74, 0));        nodes.add(new Node(91, 92, 0));        nodes.add(new Node(66, 87, 0));        nodes.add(new Node(90, 89, 0));        nodes.add(new Node(93, 72, 0));        nodes.add(new Node(26, 79, 0));        nodes.add(new Node(67, 97, 0));        nodes.add(new Node(66, 71, 0));        nodes.add(new Node(93, 89, 0));        nodes.add(new Node(41, 73, 0));        nodes.add(new Node(32, 71, 0));        nodes.add(new Node(17, 95, 0));        nodes.add(new Node(47, 52, 0));        nodes.add(new Node(19, 75, 0));        nodes.add(new Node(33, 73, 0));        nodes.add(new Node(47, 68, 0));        nodes.add(new Node(61, 71, 0));        nodes.add(new Node(39, 95, 0));        nodes.add(new Node(56, 99, 0));        nodes.add(new Node(60, 97, 0));        nodes.add(new Node(84, 90, 0));        nodes.add(new Node(25, 85, 0));        nodes.add(new Node(47, 78, 0));        nodes.add(new Node(49, 51, 0));        nodes.add(new Node(35, 63, 0));        nodes.add(new Node(54, 81, 0));        nodes.add(new Node(58, 86, 0));        nodes.add(new Node(16, 94, 0));        nodes.add(new Node(91, 82, 0));        nodes.add(new Node(36, 74, 0));        nodes.add(new Node(22, 96, 0));        nodes.add(new Node(58, 70, 0));        nodes.add(new Node(74, 95, 0));        nodes.add(new Node(80, 3, 1));        nodes.add(new Node(64, 36, 1));        nodes.add(new Node(82, 35, 1));        nodes.add(new Node(74, 38, 1));        nodes.add(new Node(54, 33, 1));        nodes.add(new Node(59, 44, 1));        nodes.add(new Node(98, 33, 1));        nodes.add(new Node(59, 1, 1));        nodes.add(new Node(74, 42, 1));        nodes.add(new Node(70, 42, 1));        nodes.add(new Node(87, 15, 1));        nodes.add(new Node(67, 32, 1));        nodes.add(new Node(90, 42, 1));        nodes.add(new Node(85, 23, 1));        nodes.add(new Node(67, 24, 1));        nodes.add(new Node(49, 3, 1));        nodes.add(new Node(57, 9, 1));        nodes.add(new Node(65, 7, 1));        nodes.add(new Node(62, 54, 1));        nodes.add(new Node(52, 23, 1));        nodes.add(new Node(81, 42, 1));        nodes.add(new Node(83, 59, 1));        nodes.add(new Node(89, 20, 1));        nodes.add(new Node(72, 19, 1));        nodes.add(new Node(61, 48, 1));        nodes.add(new Node(83, 36, 1));        nodes.add(new Node(56, 5, 1));        nodes.add(new Node(83, 1, 1));        nodes.add(new Node(48, 45, 1));        nodes.add(new Node(92, 48, 1));        nodes.add(new Node(93, 51, 1));        nodes.add(new Node(24, 60, 2));        nodes.add(new Node(9, 43, 2));        nodes.add(new Node(35, 46, 2));        nodes.add(new Node(31, 38, 2));        nodes.add(new Node(5, 57, 2));        nodes.add(new Node(10, 32, 2));        nodes.add(new Node(1, 39, 2));        nodes.add(new Node(37, 48, 2));        nodes.add(new Node(31, 44, 2));        nodes.add(new Node(38, 27, 2));        nodes.add(new Node(20, 41, 2));        nodes.add(new Node(19, 52, 2));        nodes.add(new Node(7, 54, 2));        nodes.add(new Node(23, 34, 2));        nodes.add(new Node(3, 6, 2));        nodes.add(new Node(14, 4, 2));        nodes.add(new Node(5, 22, 2));        nodes.add(new Node(34, 51, 2));        nodes.add(new Node(42, 26, 2));        nodes.add(new Node(7, 25, 2));        nodes.add(new Node(4, 8, 2));        nodes.add(new Node(15, 44, 2));        nodes.add(new Node(5, 4, 2));        nodes.add(new Node(39, 23, 2));        nodes.add(new Node(25, 64, 2));        nodes.add(new Node(11, 6, 2));        nodes.add(new Node(27, 15, 2));        nodes.add(new Node(30, 14, 2));        nodes.add(new Node(12, 49, 2));        return nodes;    }    public static void crossValidation() {        //训练集        List<Node> trains = Lists.newArrayListWithCapacity(50);        //结果集        List<Node> results = Lists.newArrayListWithCapacity(50);        //原集合        List<Node> nodes = KNN.buildNodes();        //将原集合随机分配为训练集和结果集,进行交叉验证        for (Node node : nodes) {            int target = (int) (Math.random() * 2);            if (target == 0) {                if (trains.contains(node)) {                    results.add(node);                } else {                    trains.add(node);                }            } else {                if (results.contains(node)) {                    trains.add(node);                } else {                    results.add(node);                }            }        }        //待验证的K值        int k = 20;        for (int i = 3; i < k; i++) {            //K不取偶数            if (i % 2 == 0) {                continue;            }            Map<Boolean, Integer> reMap = Maps.newHashMap();            reMap.put(true, 0);            reMap.put(false, 0);            for (Node node : results) {                int type = KNN.knn(node, trains, i);                reMap.put((type == node.getCode()), reMap.get(type == node.getCode()) + 1);            }            System.out.println("K=" + i + ":" + reMap.toString());        }    }}

解释一下:

1、原始数据是我用之前聚类文章中的代码生成的,偷了个懒。

2、依然使用了欧几里得距离。

3、k的取值很重要,过大,偏差过高,假设k=训练集的总数,那么不管新数据A是什么,必会归类于数据最多的那个类别。如果k过小,假设为1,很容易被异常数据影响。

4、可通过交叉验证法来确定k值,本文选择了Holdout Method。

其实就是训练总集分两部分,一部分是训练子集,一部分是结果集。遍历结果集,用不同的k值(>=3的奇数)和训练子集得出数据的分类,与结果集比对,计算正确率,从而选出k值。

K=3:{false=2, true=42}K=5:{false=3, true=41}K=7:{false=4, true=40}K=9:{false=6, true=38}K=11:{false=5, true=39}K=13:{false=3, true=41}K=15:{false=3, true=41}K=17:{false=5, true=39}K=19:{false=7, true=37}

本文的k=3。

下图是跑的结果:

已知数据

结果

标签: #javaknn