How to Write a Simple In-Memory Cache in Java: A Tutorial

I want to show you my implementation of a lightweight, simple in-memory key-value cache in Java.

Database queries can take a while, so it’s a good idea to store frequently used data in a cache to retrieve it faster.

Java caching abstractions like Spring Cache let you plug in your own in-memory cache implementation, so you can adapt mine.

But first, let’s define the requirements for our Java cache implementation:

  • store data in memory
  • allow storing an object by key for a given amount of time
  • memory usage is not bounded, but the cache should never be the cause of an OutOfMemoryError
  • the cache should remove expired objects
  • be thread-safe

Let’s define an API:

package com.explainjava;
 
public interface Cache {
    
    void add(String key, Object value, long periodInMillis);
 
    void remove(String key);
 
    Object get(String key);
 
    void clear();
 
    long size();
}

It looks similar to the Map API, so I’m going to use a ConcurrentHashMap under the hood.
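Here is a minimal cache-aside sketch showing how a caller might use this interface in front of a slow data source. Everything except the Cache interface is an assumption for illustration: SimpleExpiringCache is a hypothetical stand-in implementation (it checks expiry lazily in get() and has no cleanup thread), and the Supplier passed to the hypothetical getOrLoad helper plays the role of a database query.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Same contract as the Cache interface above, redeclared so the snippet compiles on its own.
interface Cache {
    void add(String key, Object value, long periodInMillis);
    void remove(String key);
    Object get(String key);
    void clear();
    long size();
}

// Hypothetical throwaway implementation for this demo only:
// expiry is checked lazily in get(), and there is no background cleanup.
class SimpleExpiringCache implements Cache {
    private static final class Entry {
        final Object value;
        final long expiryTime; // absolute time in epoch millis

        Entry(Object value, long expiryTime) {
            this.value = value;
            this.expiryTime = expiryTime;
        }
    }

    private final ConcurrentHashMap<String, Entry> map = new ConcurrentHashMap<>();

    @Override
    public void add(String key, Object value, long periodInMillis) {
        map.put(key, new Entry(value, System.currentTimeMillis() + periodInMillis));
    }

    @Override
    public void remove(String key) {
        map.remove(key);
    }

    @Override
    public Object get(String key) {
        Entry e = map.get(key);
        // Treat missing and expired entries the same way: a cache miss.
        return (e == null || System.currentTimeMillis() > e.expiryTime) ? null : e.value;
    }

    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public long size() {
        return map.size();
    }
}

final class CacheAside {
    private CacheAside() {}

    // Cache-aside: return the cached value if present; otherwise call the loader,
    // store a non-null result for ttlMillis and return it.
    static Object getOrLoad(Cache cache, String key, long ttlMillis, Supplier<Object> loader) {
        Object cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        Object loaded = loader.get();
        if (loaded != null) {
            cache.add(key, loaded, ttlMillis);
        }
        return loaded;
    }
}
```

Calling getOrLoad twice with the same key within the TTL runs the loader only once; the second call is served from the cache.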


Java Cache Implementation

Let’s look at the code first, and then I’ll explain my design decisions:

package com.explainjava;
 
import lombok.AllArgsConstructor;
import lombok.Getter;
 
import java.lang.ref.SoftReference;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
 
public class InMemoryCache implements Cache {
 
    private static final int CLEAN_UP_PERIOD_IN_SEC = 5;
 
    private final ConcurrentHashMap<String, SoftReference<CacheObject>> cache = new ConcurrentHashMap<>();
 
    public InMemoryCache() {
        Thread cleanerThread = new Thread(() -> {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    Thread.sleep(CLEAN_UP_PERIOD_IN_SEC * 1000);
                    // Remove entries that have expired or whose soft reference was already cleared by the GC
                    cache.entrySet().removeIf(entry -> Optional.ofNullable(entry.getValue()).map(SoftReference::get).map(CacheObject::isExpired).orElse(true));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        });
        cleanerThread.setDaemon(true);
        cleanerThread.start();
    }
 
    @Override
    public void add(String key, Object value, long periodInMillis) {
        if (key == null) {
            return;
        }
        if (value == null) {
            cache.remove(key);
        } else {
            long expiryTime = System.currentTimeMillis() + periodInMillis;
            cache.put(key, new SoftReference<>(new CacheObject(value, expiryTime)));
        }
    }
 
    @Override
    public void remove(String key) {
        cache.remove(key);
    }
 
    @Override
    public Object get(String key) {
        return Optional.ofNullable(cache.get(key)).map(SoftReference::get).filter(cacheObject -> !cacheObject.isExpired()).map(CacheObject::getValue).orElse(null);
    }
 
    @Override
    public void clear() {
        cache.clear();
    }
 
    @Override
    public long size() {
        return cache.entrySet().stream().filter(entry -> Optional.ofNullable(entry.getValue()).map(SoftReference::get).map(cacheObject -> !cacheObject.isExpired()).orElse(false)).count();
    }
 
    @AllArgsConstructor
    private static class CacheObject {
 
        @Getter
        private Object value;
        private long expiryTime;
 
        boolean isExpired() {
            return System.currentTimeMillis() > expiryTime;
        }
    }
}

Note that I used Lombok annotations (@AllArgsConstructor and @Getter) to generate boilerplate code; you can replace them with a hand-written constructor and getter if you prefer.
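If you’d rather not depend on Lombok, a plain-Java CacheObject could look roughly like this. It’s a sketch: I made it a top-level class and marked the fields final so the snippet compiles on its own, whereas in InMemoryCache it’s a private static nested class.

```java
// Plain-Java replacement for the Lombok-annotated CacheObject:
// a hand-written all-args constructor and a getter for value.
final class CacheObject {

    private final Object value;
    private final long expiryTime; // absolute time in epoch millis

    CacheObject(Object value, long expiryTime) {
        this.value = value;
        this.expiryTime = expiryTime;
    }

    public Object getValue() {
        return value;
    }

    // True once the current time has passed the expiry time.
    boolean isExpired() {
        return System.currentTimeMillis() > expiryTime;
    }
}
```

On Java 16 or later you could also use a record, such as record CacheObject(Object value, long expiryTime), keeping in mind that the accessor is then value() rather than getValue().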

So let me explain what is going on here:

  • I chose ConcurrentHashMap because of the thread-safety requirement.
  • I wrapped each CacheObject in a SoftReference because the JVM guarantees that all soft references to softly reachable objects are cleared before it throws an OutOfMemoryError.
  • In the constructor, I start a daemon thread that wakes up every 5 seconds and removes expired objects. The 5-second interval is arbitrary, so choose a cleanup interval that fits your workload.
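One detail the cleanup loop has to handle: once the GC clears a soft reference, SoftReference.get() returns null, and a predicate that falls back to false in that case leaves such entries in the map indefinitely. The sketch below (SoftCleanup, Item and purge are hypothetical names) treats a cleared reference as removable. The test uses SoftReference.clear() to simulate what the GC does under memory pressure.

```java
import java.lang.ref.SoftReference;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

final class SoftCleanup {
    private SoftCleanup() {}

    // Hypothetical holder mirroring CacheObject: a value plus an absolute expiry time.
    static final class Item {
        final Object value;
        final long expiryTime; // epoch millis

        Item(Object value, long expiryTime) {
            this.value = value;
            this.expiryTime = expiryTime;
        }

        boolean isExpired() {
            return System.currentTimeMillis() > expiryTime;
        }
    }

    // Removes entries that are expired OR whose referent was already cleared
    // (get() returned null). The orElse(true) fallback is what drops the cleared ones.
    static void purge(ConcurrentHashMap<String, SoftReference<Item>> cache) {
        cache.entrySet().removeIf(entry -> Optional.ofNullable(entry.getValue().get())
                .map(Item::isExpired)
                .orElse(true));
    }
}
```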

What’s wrong with this solution?

  • If the map contains a large number of cached objects, the scan and cleanup can take a while because they iterate through all the values.
  • The size() method takes O(n) time because it has to filter out expired objects.

How can we improve this?

Let’s try using a queue to remove expired objects.

Caching in Java Using Delay Queue

A DelayQueue holds elements that become available for taking only after their delay (the expiration period in our case) has elapsed, so we can use it to schedule the removal of expired objects.
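To see this behavior in isolation, here is a small sketch with a hypothetical ExpiringKey element. While no element has expired, poll() returns null; take() blocks until the head’s delay has elapsed and hands out elements in expiry order. Delays are in milliseconds, and the comparison uses the JDK’s Long.compare.

```java
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

// Hypothetical minimal Delayed element: a key that becomes available at expiryTime.
final class ExpiringKey implements Delayed {
    final String key;
    private final long expiryTime; // absolute time in epoch millis

    ExpiringKey(String key, long delayMillis) {
        this.key = key;
        this.expiryTime = System.currentTimeMillis() + delayMillis;
    }

    @Override
    public long getDelay(TimeUnit unit) {
        // Remaining time until expiry; zero or negative means the element is available.
        return unit.convert(expiryTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    }

    @Override
    public int compareTo(Delayed other) {
        // Earlier expiry sorts first; the queue only ever holds ExpiringKey instances.
        return Long.compare(expiryTime, ((ExpiringKey) other).expiryTime);
    }
}
```

Even though "slow" is put first, "fast" is taken first because its delay elapses earlier.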

Let’s take a look at a code example:

package com.explainjava;
 
import com.google.common.primitives.Longs;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
 
import java.lang.ref.SoftReference;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
 
public class InMemoryCacheWithDelayQueue implements Cache {
 
    private final ConcurrentHashMap<String, SoftReference<Object>> cache = new ConcurrentHashMap<>();
    private final DelayQueue<DelayedCacheObject> cleaningUpQueue = new DelayQueue<>();
 
    public InMemoryCacheWithDelayQueue() {
        Thread cleanerThread = new Thread(() -> {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    DelayedCacheObject delayedCacheObject = cleaningUpQueue.take();
                    cache.remove(delayedCacheObject.getKey(), delayedCacheObject.getReference());
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        });
        cleanerThread.setDaemon(true);
        cleanerThread.start();
    }
 
    @Override
    public void add(String key, Object value, long periodInMillis) {
        if (key == null) {
            return;
        }
        if (value == null) {
            cache.remove(key);
        } else {
            long expiryTime = System.currentTimeMillis() + periodInMillis;
            SoftReference<Object> reference = new SoftReference<>(value);
            cache.put(key, reference);
            cleaningUpQueue.put(new DelayedCacheObject(key, reference, expiryTime));
        }
    }
 
    @Override
    public void remove(String key) {
        cache.remove(key);
    }
 
    @Override
    public Object get(String key) {
        return Optional.ofNullable(cache.get(key)).map(SoftReference::get).orElse(null);
    }
 
    @Override
    public void clear() {
        cache.clear();
    }
 
    @Override
    public long size() {
        return cache.size();
    }
 
    @AllArgsConstructor
    @EqualsAndHashCode
    private static class DelayedCacheObject implements Delayed {
 
        @Getter
        private final String key;
        @Getter
        private final SoftReference<Object> reference;
        private final long expiryTime;
 
        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(expiryTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        }
 
        @Override
        public int compareTo(Delayed o) {
            return Longs.compare(expiryTime, ((DelayedCacheObject) o).expiryTime);
        }
    }
}

We need to implement the Delayed interface, which extends Comparable<Delayed>, and override two methods: getDelay(TimeUnit unit) and compareTo(Delayed o).

The getDelay() method returns the remaining delay in the requested time unit, that is, how long until the element can be taken from the queue.

The compareTo() method must provide an ordering consistent with getDelay().

An element can be taken from the queue only once its getDelay() value is zero or negative, so the cleaner thread removes each object right after it expires.
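The compareTo() contract can be met without Guava and, for other Delayed implementations, without an unchecked cast. In this sketch (TimedEntry is a hypothetical name), two TimedEntry instances compare their fixed expiry times, and any other Delayed is compared by remaining delay. Either way, the element with less time left sorts first.

```java
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

// Hypothetical element type whose ordering is consistent with getDelay():
// the element with the smaller remaining delay sorts first.
final class TimedEntry implements Delayed {
    private final long expiryTime; // absolute time in epoch millis

    TimedEntry(long expiryTime) {
        this.expiryTime = expiryTime;
    }

    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(expiryTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    }

    @Override
    public int compareTo(Delayed other) {
        if (other == this) {
            return 0;
        }
        if (other instanceof TimedEntry) {
            // Same type: compare the fixed expiry times, avoiding two separate clock reads.
            return Long.compare(expiryTime, ((TimedEntry) other).expiryTime);
        }
        // Any other Delayed: fall back to comparing remaining delays.
        return Long.compare(getDelay(TimeUnit.MILLISECONDS), other.getDelay(TimeUnit.MILLISECONDS));
    }
}
```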

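Now we can drop the isExpired() checks entirely, accepting that get() may briefly return a value in the short window between its expiry and the cleaner thread removing it.

The size() method is also simpler. It no longer filters out expired objects; it just delegates to ConcurrentHashMap.size(), which doesn’t iterate over the entries, so it runs in constant time. One caveat: an entry whose soft reference has been cleared by the GC is still counted until its expiry time.

And we no longer need to iterate through the whole map to find what to delete.

This implementation is more elegant, but it needs a bit more memory for the queue. Note also that remove() and clear() don’t touch the queue, so the corresponding queue entries stay there until they expire.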
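The queue is also why the cleaner thread calls cache.remove(key, reference) rather than cache.remove(key). If a key is re-added before its old entry expires, the queue holds both the old and the new DelayedCacheObject, and the two-argument remove deletes the mapping only while it still points at the queued reference. Since SoftReference doesn’t override equals(), this is an identity check. The sketch below isolates that behavior; ConditionalRemoveDemo and evictIfUnchanged are hypothetical names.

```java
import java.lang.ref.SoftReference;
import java.util.concurrent.ConcurrentHashMap;

final class ConditionalRemoveDemo {
    private ConditionalRemoveDemo() {}

    // Mirrors what the cleaner does with an expired queue entry: remove the key
    // only if it is still mapped to the exact reference that was queued.
    // Returns true if the mapping was removed.
    static boolean evictIfUnchanged(ConcurrentHashMap<String, SoftReference<Object>> cache,
                                    String key, SoftReference<Object> queuedReference) {
        return cache.remove(key, queuedReference);
    }
}
```

When the key has been overwritten, the stale queue entry leaves the fresh value alone; the matching entry removes it once it expires.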

Conclusion

This is a simple example of a hand-written in-memory cache. You can use it to store the results of long-running database queries or other frequently used data.

My implementation is deliberately simple. For more complex cases, you should use a dedicated caching library such as Ehcache or a distributed cache such as Memcached.

I can’t say which Java cache library is the best; it depends on your needs.

Think about what best fits cache management in your application, and decide accordingly.

Do you have any questions or suggestions? Ask me.
