Explore use of Java Vector API for SIMD support #739

Open
headius opened this issue Jan 29, 2025 · 3 comments

@headius
Contributor

headius commented Jan 29, 2025

Inspired by @samyron's work in #730, I'd like to explore the potential of Java's Vector API in Psych.

https://openjdk.org/jeps/489

The API has been gestating for many years, but it can be enabled and used on every JDK since 16 by adding the incubator module with --add-modules jdk.incubator.vector. The potential here is to get SIMD performance without writing platform-specific code, and to turn it on only when the Vector API is enabled at the JVM level.
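A minimal sketch of the "only when the Vector API is enabled" idea, assuming the switch is the standard --add-modules jdk.incubator.vector flag. The class checks whether the incubator module was resolved into the boot layer and otherwise stays on a portable scalar path. VectorGate, needsEscapeScalar, and the escape rule (control characters, backslash, double quote) are illustrative assumptions, not code from any existing extension.

```java
import java.nio.charset.StandardCharsets;

public class VectorGate {
    // True only if jdk.incubator.vector was resolved into the boot layer,
    // e.g. because the JVM was started with --add-modules jdk.incubator.vector.
    static final boolean VECTOR_API_AVAILABLE =
            ModuleLayer.boot().findModule("jdk.incubator.vector").isPresent();

    // Portable scalar fallback: does any byte need escaping?
    // Bytes are compared unsigned so UTF-8 bytes >= 0x80 are not flagged.
    static boolean needsEscapeScalar(byte[] bytes) {
        for (byte b : bytes) {
            int u = Byte.toUnsignedInt(b);
            if (u < 0x20 || u == '\\' || u == '"') {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // A real extension would dispatch to a Vector API implementation here
        // when VECTOR_API_AVAILABLE is true (hypothetical; not shown).
        System.out.println("Vector API available: " + VECTOR_API_AVAILABLE);
        byte[] text = "caf\u00e9 \"quoted\"".getBytes(StandardCharsets.UTF_8);
        System.out.println("needs escape: " + needsEscapeScalar(text));
    }
}
```

Because this class never imports jdk.incubator.vector, it compiles and loads on a JVM started without the flag; a vector implementation would typically live in a separate class that is only touched when VECTOR_API_AVAILABLE is true.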

This could also be a fun project for someone else who wants to play with truly bleeding-edge JVM features and help out JRuby at the same time.

I would love to hear from @samyron about more ideas for SIMD optimization of Psych, and try to implement as many of those ideas as possible in the JRuby extension.

@samyron

samyron commented Jan 30, 2025

I'm happy to take a look. I haven't looked at the Java Vector API yet. However, it might be easier to implement ideas similar to what I did in #730, since the JVM handles CPU/ISA detection.
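A quick way to see what the JVM chose on a particular machine is to print the preferred species. This sketch assumes JDK 16 or later with --add-modules jdk.incubator.vector at both compile and run time; the reported width depends on the CPU and on JVM flags, so no particular output is implied.

```java
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;

// Compile and run with: --add-modules jdk.incubator.vector
public class SpeciesInfo {
    public static void main(String[] args) {
        // The JVM picks the widest vector shape it considers efficient on this CPU.
        VectorShape shape = VectorShape.preferredShape();
        VectorSpecies<Byte> species = ByteVector.SPECIES_PREFERRED;

        System.out.println("preferred shape: " + shape);
        System.out.println("byte lanes per vector: " + species.length());
        System.out.println("vector bit size: " + species.vectorBitSize());
    }
}
```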

@headius
Contributor Author

headius commented Mar 12, 2025

The API in question has been incubating in the JDK for many years and is still experimental, but information on its ninth incubator version is here: https://openjdk.org/jeps/489

The API looks pretty straightforward to use:

static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

void vectorComputation(float[] a, float[] b, float[] c) {
    int i = 0;
    int upperBound = SPECIES.loopBound(a.length);
    for (; i < upperBound; i += SPECIES.length()) {
        // FloatVector va, vb, vc;
        var va = FloatVector.fromArray(SPECIES, a, i);
        var vb = FloatVector.fromArray(SPECIES, b, i);
        var vc = va.mul(va)
                   .add(vb.mul(vb))
                   .neg();
        vc.intoArray(c, i);
    }
    for (; i < a.length; i++) {
        c[i] = (a[i] * a[i] + b[i] * b[i]) * -1.0f;
    }
}

Other examples, along with the assembly they compile to, can be found in the JEP and related documentation.
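For anyone who wants to try the JEP example directly, here is a sketch of a self-contained harness, assuming JDK 16 or later and --add-modules jdk.incubator.vector for both javac and java (incubator modules are not resolved by default). It runs the vector loop alongside a plain scalar loop and compares the results; the class name, the scalarComputation helper, and the input data are arbitrary choices for illustration.

```java
import java.util.Arrays;

import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;

// javac --add-modules jdk.incubator.vector VectorDemo.java
// java  --add-modules jdk.incubator.vector VectorDemo
public class VectorDemo {
    static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    // Vector loop from JEP 489: c[i] = -(a[i]^2 + b[i]^2)
    static void vectorComputation(float[] a, float[] b, float[] c) {
        int i = 0;
        int upperBound = SPECIES.loopBound(a.length);
        for (; i < upperBound; i += SPECIES.length()) {
            var va = FloatVector.fromArray(SPECIES, a, i);
            var vb = FloatVector.fromArray(SPECIES, b, i);
            va.mul(va).add(vb.mul(vb)).neg().intoArray(c, i);
        }
        // Scalar tail for elements that don't fill a whole vector.
        for (; i < a.length; i++) {
            c[i] = (a[i] * a[i] + b[i] * b[i]) * -1.0f;
        }
    }

    // Reference implementation used to check the vector version.
    static void scalarComputation(float[] a, float[] b, float[] c) {
        for (int i = 0; i < a.length; i++) {
            c[i] = (a[i] * a[i] + b[i] * b[i]) * -1.0f;
        }
    }

    public static void main(String[] args) {
        int n = 1003; // deliberately odd, so the scalar tail loop also runs
        float[] a = new float[n], b = new float[n];
        for (int i = 0; i < n; i++) {
            a[i] = i * 0.5f;
            b[i] = (n - i) * 0.25f;
        }
        float[] cv = new float[n], cs = new float[n];
        vectorComputation(a, b, cv);
        scalarComputation(a, b, cs);
        System.out.println("lanes: " + SPECIES.length());
        System.out.println("results match: " + Arrays.equals(cv, cs));
    }
}
```

Exact equality is a reasonable check here because the Vector API's lanewise mul, add, and neg follow the same IEEE 754 rules as the scalar Java operators.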

I'll have a look at what you did in #730 and see if I can move things in the right direction.

@samyron

samyron commented Mar 14, 2025

Apologies... this fell off my radar. This might be a good starting point:

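package vectortest;

import java.nio.charset.StandardCharsets;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Compile and run with: --add-modules jdk.incubator.vector
public class Main {
	public static void main(String[] args) {
		String str = "This \"is\" a test of the \"emergency\" broadcast system. Do not be alarmed.";

		// Widest byte vector the JVM considers efficient on this CPU
		VectorSpecies<Byte> species = ByteVector.SPECIES_PREFERRED;
		System.out.println(species);

		// Fill every lane with the character being searched for
		Vector<Byte> space = species.broadcast(' ');
		Vector<Byte> backslash = species.broadcast('\\');
		Vector<Byte> doubleQuote = species.broadcast('\"');

		// Encode explicitly instead of relying on the platform default charset
		byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
		int offset = 0;
		// "<=" so the last full chunk is also handled by the vector loop
		while (offset + species.length() <= bytes.length) {
			ByteVector chunk = ByteVector.fromArray(species, bytes, offset);
			System.out.println(chunk);

			// Java bytes are signed: compare unsigned so UTF-8 bytes >= 0x80
			// are not mistaken for control characters
			VectorMask<Byte> mask1 = chunk.compare(VectorOperators.ULT, space);
			VectorMask<Byte> mask2 = chunk.eq(backslash);
			VectorMask<Byte> mask3 = chunk.eq(doubleQuote);

			VectorMask<Byte> needsEscape = mask1.or(mask2).or(mask3);
			System.out.println(needsEscape);

			if (needsEscape.anyTrue()) {
				System.out.println("Some byte(s) in this chunk need to be escaped.");
			}

			offset += species.length();
		}

		// Scalar tail for the bytes that don't fill a whole vector (unsigned, as above)
		for (int i = offset; i < bytes.length; i++) {
			int b = Byte.toUnsignedInt(bytes[i]);
			if ((b < ' ') || (b == '\\') || (b == '\"')) {
				System.out.println("Need to escape this byte.");
			}
		}
	}
}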
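Building on that starting point, an escaper usually needs the position of the first byte to escape rather than a yes/no answer, so it can copy the clean prefix in bulk. A sketch, assuming the same escape rule and --add-modules jdk.incubator.vector; the class EscapeScan and the method firstEscapeIndex are hypothetical names. It compares unsigned so UTF-8 bytes >= 0x80 are not flagged, and relies on VectorMask.firstTrue(), which returns the lane index of the first set lane (or the species length when none is set).

```java
import java.nio.charset.StandardCharsets;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class EscapeScan {
    static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;

    // Hypothetical helper: index of the first byte that needs escaping, or -1.
    static int firstEscapeIndex(byte[] bytes) {
        int i = 0;
        int upperBound = SPECIES.loopBound(bytes.length);
        for (; i < upperBound; i += SPECIES.length()) {
            ByteVector chunk = ByteVector.fromArray(SPECIES, bytes, i);
            VectorMask<Byte> needsEscape = chunk.compare(VectorOperators.ULT, (byte) 0x20)
                    .or(chunk.eq((byte) '\\'))
                    .or(chunk.eq((byte) '"'));
            if (needsEscape.anyTrue()) {
                // firstTrue() gives the lane offset within this chunk
                return i + needsEscape.firstTrue();
            }
        }
        // Scalar tail, same rule with an unsigned comparison.
        for (; i < bytes.length; i++) {
            int u = Byte.toUnsignedInt(bytes[i]);
            if (u < 0x20 || u == '\\' || u == '"') {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        String str = "This \"is\" a test of the \"emergency\" broadcast system. Do not be alarmed.";
        byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
        System.out.println(firstEscapeIndex(bytes)); // 5, the first double quote
    }
}
```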
