手写 JDK 1.7 的 HashMap 实现

前言

这里将参考 JDK 1.7 的 HashMap 底层源码,模拟手写一个简易版的 HashMap。

思路

JDK 1.7 是如何处理哈希冲突的?

在 JDK 1.7 中,HashMap 在处理哈希冲突时采用的是链地址法(Separate Chaining)。当发生哈希冲突时,即多个键被映射到了同一个桶(数组位置),HashMap 会将这些键值对存储在同一个桶对应的链表中。具体来说,在 JDK 1.7 中,HashMap 的每个桶(数组位置)实际上是一个链表,每个链表存储了哈希值相同的键值对。当执行 Put 操作时,HashMap 首先会计算键的哈希值,然后确定该键应该存储在数组的哪个位置。如果该位置已经存在了链表,HashMap 就会遍历该链表,检查是否已经存在相同键的键值对。如果存在相同的键,则 HashMap 会更新相应的值;如果不存在相同的键,则 HashMap 会将新的键值对添加到链表的末尾。链地址法的优点是它能够处理哈希冲突,并且在一定程度上保持了 HashMap 的性能。然而,在负载因子较高的情况下,即链表较长的情况下,查询键值对的效率可能会降低,因为需要遍历链表来找到目标键值对。

JDK 1.8+ 是如何处理哈希冲突的?

在 JDK 1.8 之后,HashMap 的实现发生了变化。JDK 1.8 引入了红黑树来替代链表,以改善在负载因子较高时的性能,这种结构称为 “链表与红黑树混合实现”。具体来说,当哈希冲突发生时,如果链表的长度超过一定阈值(默认为 8),HashMap 会将链表转换为红黑树。这样做的目的是为了在链表长度较长时提高查询、插入和删除操作的效率,因为红黑树的时间复杂度更稳定,为 O (log n)。而当链表长度较短时,仍然保持使用链表结构,因为在较短的链表中,链表的遍历效率更高。

提示

从 JDK 8u40 版本开始,引入了 "树化优化"(Tree bins optimization),也就是 JDK 的底层是基于 数组 + 链表 + 红黑树 + 树化优化 来实现 HashMap。这个优化主要是在进行树化操作时,会先判断当前链表长度是否大于等于 8,如果不是,则不会进行树化操作,以节省资源。这个优化主要是为了解决在一些场景下,链表长度虽然超过了阈值,但树化操作并不能带来性能提升的问题。在一些更新的 JDK 版本中,可能还会引入其他优化措施,比如移位操作的优化,以进一步提升 HashMap 的性能。

代码

  • 定义接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public interface MyMap<K, V> {

public V put(K k, V v);

public V get(K k);

public int size();

public interface Entry<K, V> {

public K getKey();

public V getValue();

}

}
  • 基于数组 + 单向链表实现 HashMap
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
public class MyHashMap<K, V> implements MyMap<K, V> {

private static int defaultLength = 16; // 默认容量
private static float defaultLoader = 0.7f; // 加载因子
private Entry<K, V>[] table = null; // 数组
private int size = 0; // 元素数量

public MyHashMap() {
this(defaultLength, defaultLoader);
}

public MyHashMap(int length, float loader) {
defaultLength = length;
defaultLoader = loader;
// 初始化数组
table = new Entry[defaultLength];
}

/**
* 哈希算法
* <p> 哈希算法决定了运行效率(时间复杂度)
*/
private int hash(K k) {
int l = defaultLength;
int i = k.hashCode() % l;
return i > 0 ? i : -i;
}

/**
* 创建节点
*/
private Entry<K, V> newEntry(K k, V v, Entry<K, V> next) {
return new Entry<>(k, v, next);
}

/**
* 查找节点(递归)
*/
private V find(K k, Entry<K, V> entry) {
if (k == entry.getKey() || k.equals(entry.getKey())) {
return entry.getValue();
} else {
if (entry.next != null) {
return find(k, entry.next);
}
}
return null;
}

@Override
public V put(K k, V v) {
int index = hash(k);
Entry<K, V> entry = table[index];
if (entry == null) {
size++;
table[index] = newEntry(k, v, null);
} else {
// 使用链表(单向)来解决哈希冲突问题
table[index] = newEntry(k, v, entry);
}
return table[index].getValue();
}

@Override
public V get(K k) {
int index = hash(k);
if (index >= defaultLength) {
return null;
}
Entry<K, V> entry = table[index];
return null == entry ? null : find(k, entry);
}

@Override
public int size() {
return size;
}

/**
* 定义节点
* <p> 实现了链表(单向)的数据结构
*/
static class Entry<K, V> implements MyMap.Entry<K, V> {

K k; // 键
V v; // 值
Entry<K, V> next; // 下一个节点

public Entry(K k, V v) {
this.k = k;
this.v = v;
this.next = null;
}

public Entry(K k, V v, Entry<K, V> next) {
this.k = k;
this.v = v;
this.next = next;
}

@Override
public K getKey() {
return k;
}

@Override
public V getValue() {
return v;
}

}

}

测试一

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class MyHashMapTest1 {

public static void main(String[] args) {
MyMap<String, Integer> map = new MyHashMap<>();
map.put("C++", 1);
map.put("Java", 2);
map.put("Python", 3);
map.put("Python", 4);

System.out.println(map.get("C++"));
System.out.println(map.get("Java"));
System.out.println(map.get("Python"));
System.out.println("size = " + map.size());
}

}

输出结果:

1
2
3
4
1
2
4
size = 3

测试二

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class MyHashMapTest2 {

public static void main(String[] args) {
MyMap<String, Integer> map = new MyHashMap<>();
long star = System.currentTimeMillis();
for (int i = 0; i < 1000; i++) {
map.put("Java" + i, i);
}
for (int i = 0; i < 1000; i++) {
map.get("Java");
}
long end = System.currentTimeMillis();
System.out.println("耗时:" + (end - star) + "ms");
}

}

输出结果:

1
耗时:13ms