I want to show you my implementation of a lightweight, simple in-memory key-value cache in Java.
Database queries can take time, so it’s a good idea to keep frequently used data in a cache where it can be retrieved faster.
Java caching frameworks like Spring Cache let you define your own in-memory cache implementation, so you can adapt mine.
But first of all, let’s define the requirements for our Java cache implementation:
- store data in memory
- allow putting an object by key for a given amount of time
- memory usage is not restricted, but the cache shouldn’t be the reason for an OutOfMemoryError
- the cache should remove expired objects
- be thread-safe
Let’s define an API:
package com.explainjava;
public interface Cache {
void add(String key, Object value, long periodInMillis);
void remove(String key);
Object get(String key);
void clear();
long size();
}
It looks similar to the Map API, and I’m going to use a ConcurrentHashMap in our example.
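To see how closely the interface follows the Map API, here is a minimal sketch of a hypothetical MapBackedCache that delegates each method straight to a ConcurrentHashMap. It deliberately ignores periodInMillis, so nothing ever expires; expiry is exactly the piece the real implementations have to add. The interface is repeated so the snippet compiles on its own.

```java
import java.util.concurrent.ConcurrentHashMap;

// The Cache interface from the article, repeated so this sketch compiles on its own.
interface Cache {
    void add(String key, Object value, long periodInMillis);
    void remove(String key);
    Object get(String key);
    void clear();
    long size();
}

// Hypothetical, non-expiring cache: every method maps directly onto a ConcurrentHashMap call.
// periodInMillis is ignored on purpose, so entries live until they are removed explicitly.
class MapBackedCache implements Cache {
    private final ConcurrentHashMap<String, Object> map = new ConcurrentHashMap<>();

    @Override
    public void add(String key, Object value, long periodInMillis) {
        if (key == null) {
            return; // ConcurrentHashMap rejects null keys
        }
        if (value == null) {
            map.remove(key); // it also rejects null values, so treat null as "remove"
        } else {
            map.put(key, value);
        }
    }

    @Override
    public void remove(String key) {
        if (key != null) {
            map.remove(key);
        }
    }

    @Override
    public Object get(String key) {
        return key == null ? null : map.get(key);
    }

    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public long size() {
        return map.size();
    }
}
```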
Java Cache Implementation
Let’s take a look at an example first and then I’ll explain my architecture decisions:
package com.explainjava;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.lang.ref.SoftReference;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
public class InMemoryCache implements Cache {
private static final int CLEAN_UP_PERIOD_IN_SEC = 5;
private final ConcurrentHashMap<String, SoftReference<CacheObject>> cache = new ConcurrentHashMap<>();
public InMemoryCache() {
Thread cleanerThread = new Thread(() -> {
while (!Thread.currentThread().isInterrupted()) {
try {
Thread.sleep(CLEAN_UP_PERIOD_IN_SEC * 1000);
// Remove entries that have expired or whose soft reference was already cleared by the GC
cache.entrySet().removeIf(entry -> Optional.ofNullable(entry.getValue()).map(SoftReference::get).map(CacheObject::isExpired).orElse(true));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});
cleanerThread.setDaemon(true);
cleanerThread.start();
}
@Override
public void add(String key, Object value, long periodInMillis) {
if (key == null) {
return;
}
if (value == null) {
cache.remove(key);
} else {
long expiryTime = System.currentTimeMillis() + periodInMillis;
cache.put(key, new SoftReference<>(new CacheObject(value, expiryTime)));
}
}
@Override
public void remove(String key) {
cache.remove(key);
}
@Override
public Object get(String key) {
return Optional.ofNullable(cache.get(key)).map(SoftReference::get).filter(cacheObject -> !cacheObject.isExpired()).map(CacheObject::getValue).orElse(null);
}
@Override
public void clear() {
cache.clear();
}
@Override
public long size() {
return cache.entrySet().stream().filter(entry -> Optional.ofNullable(entry.getValue()).map(SoftReference::get).map(cacheObject -> !cacheObject.isExpired()).orElse(false)).count();
}
@AllArgsConstructor
private static class CacheObject {
@Getter
private Object value;
private long expiryTime;
boolean isExpired() {
return System.currentTimeMillis() > expiryTime;
}
}
}
Note that I used some Lombok annotations, such as @AllArgsConstructor and @Getter, to generate boilerplate code; you can replace them with hand-written code if you want.
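If you’d rather not depend on Lombok, the annotated CacheObject can be written by hand. The sketch below is roughly what @AllArgsConstructor and @Getter generate for it: a constructor taking all fields in declaration order and a getValue() accessor. It is shown as a top-level class only so it compiles standalone; inside InMemoryCache it stays a private static nested class, and making the fields final is my own small assumption since they never change.

```java
// Lombok-free equivalent of the nested CacheObject class from InMemoryCache.
// (Top-level here so the snippet compiles standalone; in the cache it is private static.)
class CacheObject {
    private final Object value;
    private final long expiryTime; // absolute time in epoch milliseconds

    // Roughly what @AllArgsConstructor generates: one parameter per field, in declaration order.
    CacheObject(Object value, long expiryTime) {
        this.value = value;
        this.expiryTime = expiryTime;
    }

    // What @Getter on the value field generates.
    public Object getValue() {
        return value;
    }

    boolean isExpired() {
        return System.currentTimeMillis() > expiryTime;
    }
}
```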
So let me explain what is going on here:
- I chose ConcurrentHashMap because of the thread-safety requirement.
- I wrapped each map value in a SoftReference (SoftReference<CacheObject>) because all softly reachable objects are guaranteed to be cleared by the garbage collector before an OutOfMemoryError is thrown.
- In the constructor, I start a daemon thread that wakes up every 5 seconds and removes expired objects. 5 seconds is an arbitrary value, so think about which cleaning interval suits your application.
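The hand-rolled sleep loop works, but the JDK’s ScheduledExecutorService can run the same periodic clean-up for you. The sketch below swaps in scheduleAtFixedRate on a single daemon thread; the class name PeriodicCleaner, the plain Long expiry-time values, and the configurable period are illustrative assumptions, not part of the original cache.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Hypothetical helper: periodically removes entries whose expiry time (epoch millis) has passed.
// Values are plain expiry timestamps to keep the example small.
class PeriodicCleaner implements AutoCloseable {
    final ConcurrentHashMap<String, Long> expiryTimes = new ConcurrentHashMap<>();
    private final ScheduledExecutorService scheduler;

    PeriodicCleaner(long periodInMillis) {
        // Daemon thread, like the original cleanerThread, so it never keeps the JVM alive.
        scheduler = Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "cache-cleaner");
            thread.setDaemon(true);
            return thread;
        });
        // Run the clean-up repeatedly, starting after one period.
        scheduler.scheduleAtFixedRate(this::cleanUp, periodInMillis, periodInMillis, TimeUnit.MILLISECONDS);
    }

    void put(String key, long ttlInMillis) {
        expiryTimes.put(key, System.currentTimeMillis() + ttlInMillis);
    }

    // Same idea as the removeIf call in InMemoryCache: one full scan per run.
    private void cleanUp() {
        long now = System.currentTimeMillis();
        expiryTimes.values().removeIf(expiryTime -> now > expiryTime);
    }

    @Override
    public void close() {
        scheduler.shutdownNow(); // cancels future runs and interrupts a running one
    }
}
```

One design note: if a scheduled task ever throws, scheduleAtFixedRate suppresses all of its later runs, so keep the clean-up code free of exceptions.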
What’s wrong with this solution?
- If the map contains a large number of cached objects, the scan and clean-up can take a while, because it iterates through all values.
- The size() method takes O(n) time, because it has to filter out expired objects.
How can we improve this?
Let’s try to use a queue for removing expired objects.
Caching in Java Using Delay Queue
DelayQueue lets you add elements with a delay (the expiration period in our case); an element can be taken from the queue only after its delay has elapsed, so we can use it to schedule the removal of expired objects.
Let’s take a look at a code example:
package com.explainjava;
import com.google.common.primitives.Longs;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.lang.ref.SoftReference;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
public class InMemoryCacheWithDelayQueue implements Cache {
private final ConcurrentHashMap<String, SoftReference<Object>> cache = new ConcurrentHashMap<>();
private final DelayQueue<DelayedCacheObject> cleaningUpQueue = new DelayQueue<>();
public InMemoryCacheWithDelayQueue() {
Thread cleanerThread = new Thread(() -> {
while (!Thread.currentThread().isInterrupted()) {
try {
DelayedCacheObject delayedCacheObject = cleaningUpQueue.take();
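// Remove only if the key still maps to this exact reference, so a newer value stored under the same key survives
cache.remove(delayedCacheObject.getKey(), delayedCacheObject.getReference());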
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});
cleanerThread.setDaemon(true);
cleanerThread.start();
}
@Override
public void add(String key, Object value, long periodInMillis) {
if (key == null) {
return;
}
if (value == null) {
cache.remove(key);
} else {
long expiryTime = System.currentTimeMillis() + periodInMillis;
SoftReference<Object> reference = new SoftReference<>(value);
cache.put(key, reference);
cleaningUpQueue.put(new DelayedCacheObject(key, reference, expiryTime));
}
}
@Override
public void remove(String key) {
cache.remove(key);
}
@Override
public Object get(String key) {
return Optional.ofNullable(cache.get(key)).map(SoftReference::get).orElse(null);
}
@Override
public void clear() {
cache.clear();
}
@Override
public long size() {
return cache.size();
}
@AllArgsConstructor
@EqualsAndHashCode
private static class DelayedCacheObject implements Delayed {
@Getter
private final String key;
@Getter
private final SoftReference<Object> reference;
private final long expiryTime;
@Override
public long getDelay(TimeUnit unit) {
return unit.convert(expiryTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o) {
return Longs.compare(expiryTime, ((DelayedCacheObject) o).expiryTime);
}
}
}
We need to implement the Delayed interface and override two methods: getDelay(TimeUnit unit) and compareTo(Delayed o).
The getDelay() method returns the time remaining before the object becomes available from the queue.
The compareTo() method must provide an ordering consistent with getDelay().
An element can be taken from the queue only when its getDelay() value is zero or negative, so the cleaner thread removes each entry as soon as it expires.
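To see these rules in isolation, here is a small, self-contained sketch of a Delayed element in a DelayQueue. The ExpiringKey and DelayQueueDemo classes are hypothetical; the sketch uses the JDK’s Long.compare instead of Guava’s Longs.compare. It shows that poll() returns null while no element’s delay has elapsed, whereas take() blocks until the element with the earliest expiry becomes available.

```java
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

// Hypothetical Delayed element: a key that becomes available once expiryTime (epoch millis) passes.
class ExpiringKey implements Delayed {
    final String key;
    private final long expiryTime;

    ExpiringKey(String key, long delayInMillis) {
        this.key = key;
        this.expiryTime = System.currentTimeMillis() + delayInMillis;
    }

    // Remaining delay; zero or negative means the element can be taken from the queue.
    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(expiryTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    }

    // Ordering consistent with getDelay(): the earlier expiry time comes first.
    @Override
    public int compareTo(Delayed other) {
        if (other instanceof ExpiringKey) {
            return Long.compare(expiryTime, ((ExpiringKey) other).expiryTime);
        }
        return Long.compare(getDelay(TimeUnit.MILLISECONDS), other.getDelay(TimeUnit.MILLISECONDS));
    }
}

class DelayQueueDemo {
    // Adds two keys out of expiry order and returns them in the order take() hands them back.
    static String[] drainInExpiryOrder() throws InterruptedException {
        DelayQueue<ExpiringKey> queue = new DelayQueue<>();
        queue.put(new ExpiringKey("later", 200));
        queue.put(new ExpiringKey("sooner", 50));
        if (queue.poll() != null) {
            throw new IllegalStateException("nothing should be available yet");
        }
        String first = queue.take().key;  // blocks for roughly 50 ms
        String second = queue.take().key; // blocks until roughly 200 ms have passed
        return new String[] {first, second};
    }
}
```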
Now we can drop the isExpired() checks entirely.
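The size() method no longer needs to filter out expired objects, so it now runs in constant time. Note that it still counts entries whose soft references have been cleared by the garbage collector, until their delay elapses.
We also no longer need to iterate through the whole map to find what to delete.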
This implementation is more elegant, but it needs a little more memory because of the queue.
Conclusion
This is a simple example of a hand-written in-memory cache; you can use it to store the results of long-running database queries or other frequently used data.
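A common way to cache the results of a slow query is the cache-aside pattern: look the key up first, and only on a miss run the query and store its result with a time to live. The sketch below repeats the Cache interface so it compiles on its own and uses a hypothetical, deliberately simple SimpleTtlCache and a hypothetical getOrLoad helper; in practice you would pass in InMemoryCache or InMemoryCacheWithDelayQueue instead.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// The Cache interface from the article, repeated so this sketch compiles on its own.
interface Cache {
    void add(String key, Object value, long periodInMillis);
    void remove(String key);
    Object get(String key);
    void clear();
    long size();
}

// Hypothetical stand-in implementation: checks expiry on read instead of using a cleaner thread.
class SimpleTtlCache implements Cache {
    private static final class Entry {
        final Object value;
        final long expiryTime;

        Entry(Object value, long expiryTime) {
            this.value = value;
            this.expiryTime = expiryTime;
        }
    }

    private final ConcurrentHashMap<String, Entry> map = new ConcurrentHashMap<>();

    @Override
    public void add(String key, Object value, long periodInMillis) {
        if (key == null) {
            return;
        }
        if (value == null) {
            map.remove(key);
        } else {
            map.put(key, new Entry(value, System.currentTimeMillis() + periodInMillis));
        }
    }

    @Override
    public void remove(String key) {
        if (key != null) {
            map.remove(key);
        }
    }

    @Override
    public Object get(String key) {
        Entry entry = key == null ? null : map.get(key);
        return entry == null || System.currentTimeMillis() > entry.expiryTime ? null : entry.value;
    }

    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public long size() {
        return map.size(); // may include expired entries in this simplified version
    }
}

class CacheAside {
    // Returns the cached value, or runs the (slow) loader on a miss and caches its result.
    // Two threads missing at the same time may both run the loader; that is tolerated here.
    static Object getOrLoad(Cache cache, String key, long ttlInMillis, Supplier<Object> loader) {
        Object cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        Object loaded = loader.get();
        cache.add(key, loaded, ttlInMillis);
        return loaded;
    }
}
```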
My implementation is really simple; for more complex cases, you should use distributed cache solutions like Memcached, Ehcache, etc.
I can’t say which Java cache library is the best; it depends on your needs.
You should think about what the best choice for cache management is in your application and decide accordingly.
Do you have any questions or suggestions? Ask me.