Add Decoder#ofPredicate, Decoder#contramap.

Change-Id: Ifd4e372a6a3c3028d1cd74e6d9a0145c3f571ff5
diff --git a/jgvariant-core/src/main/java/eu/mulk/jgvariant/core/Decoder.java b/jgvariant-core/src/main/java/eu/mulk/jgvariant/core/Decoder.java
index 33b0480..3beb247 100644
--- a/jgvariant-core/src/main/java/eu/mulk/jgvariant/core/Decoder.java
+++ b/jgvariant-core/src/main/java/eu/mulk/jgvariant/core/Decoder.java
@@ -15,8 +15,11 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.function.UnaryOperator;
 import org.apiguardian.api.API;
 import org.apiguardian.api.API.Status;
 import org.jetbrains.annotations.Nullable;
@@ -101,6 +104,29 @@
   }
 
   /**
+   * Creates a new {@link Decoder} from an existing one by applying a function to the input.
+   *
+   * @param function the function to apply.
+   * @return a new, decorated {@link Decoder}.
+   * @see java.util.stream.Stream#map
+   */
+  public final Decoder<T> contramap(UnaryOperator<ByteBuffer> function) {
+    return new ContramappingDecoder(function);
+  }
+
+  /**
+   * Creates a new {@link Decoder} that delegates to one of two other {@link Decoder}s based on a
+   * condition on the input {@link ByteBuffer}.
+   *
+   * @param selector the predicate to use to determine the decoder to use.
+   * @return a new {@link Decoder}.
+   */
+  public static <U> Decoder<U> ofPredicate(
+      Predicate<ByteBuffer> selector, Decoder<U> thenDecoder, Decoder<U> elseDecoder) {
+    return new PredicateDecoder<>(selector, thenDecoder, elseDecoder);
+  }
+
+  /**
    * Creates a {@link Decoder} for an {@code Array} type.
    *
    * @param elementDecoder a {@link Decoder} for the elements of the array.
@@ -773,7 +799,7 @@
 
     private final Function<T, U> function;
 
-    public MappingDecoder(Function<T, U> function) {
+    MappingDecoder(Function<T, U> function) {
       this.function = function;
     }
 
@@ -793,11 +819,36 @@
     }
   }
 
+  private class ContramappingDecoder extends Decoder<T> {
+
+    private final UnaryOperator<ByteBuffer> function;
+
+    ContramappingDecoder(UnaryOperator<ByteBuffer> function) {
+      this.function = function;
+    }
+
+    @Override
+    public byte alignment() {
+      return Decoder.this.alignment();
+    }
+
+    @Override
+    public @Nullable Integer fixedSize() {
+      return Decoder.this.fixedSize();
+    }
+
+    @Override
+    public T decode(ByteBuffer byteSlice) {
+      var transformedBuffer = function.apply(byteSlice.asReadOnlyBuffer().order(byteSlice.order()));
+      return Decoder.this.decode(transformedBuffer);
+    }
+  }
+
   private class ByteOrderFixingDecoder extends Decoder<T> {
 
     private final ByteOrder byteOrder;
 
-    public ByteOrderFixingDecoder(ByteOrder byteOrder) {
+    ByteOrderFixingDecoder(ByteOrder byteOrder) {
       this.byteOrder = byteOrder;
     }
 
@@ -822,4 +873,46 @@
   private static ByteBuffer slicePreservingOrder(ByteBuffer byteSlice, int index, int length) {
     return byteSlice.slice(index, length).order(byteSlice.order());
   }
+
+  private static class PredicateDecoder<U> extends Decoder<U> {
+
+    private final Predicate<ByteBuffer> selector;
+    private final Decoder<U> thenDecoder;
+    private final Decoder<U> elseDecoder;
+
+    PredicateDecoder(
+        Predicate<ByteBuffer> selector, Decoder<U> thenDecoder, Decoder<U> elseDecoder) {
+      this.selector = selector;
+      this.thenDecoder = thenDecoder;
+      this.elseDecoder = elseDecoder;
+      if (thenDecoder.alignment() != elseDecoder.alignment()) {
+        throw new IllegalArgumentException(
+            "incompatible alignments in predicate branches: then=%d, else=%d"
+                .formatted(thenDecoder.alignment(), elseDecoder.alignment()));
+      }
+
+      if (!Objects.equals(thenDecoder.fixedSize(), elseDecoder.fixedSize())) {
+        throw new IllegalArgumentException(
+            "incompatible sizes in predicate branches: then=%s, else=%s"
+                .formatted(thenDecoder.fixedSize(), elseDecoder.fixedSize()));
+      }
+    }
+
+    @Override
+    public byte alignment() {
+      return thenDecoder.alignment();
+    }
+
+    @Override
+    public @Nullable Integer fixedSize() {
+      return thenDecoder.fixedSize();
+    }
+
+    @Override
+    public U decode(ByteBuffer byteSlice) {
+      var b = selector.test(byteSlice);
+      byteSlice.rewind();
+      return b ? thenDecoder.decode(byteSlice) : elseDecoder.decode(byteSlice);
+    }
+  }
 }
diff --git a/jgvariant-core/src/test/java/eu/mulk/jgvariant/core/DecoderTest.java b/jgvariant-core/src/test/java/eu/mulk/jgvariant/core/DecoderTest.java
index efbcafa..8c78692 100644
--- a/jgvariant-core/src/test/java/eu/mulk/jgvariant/core/DecoderTest.java
+++ b/jgvariant-core/src/test/java/eu/mulk/jgvariant/core/DecoderTest.java
@@ -507,4 +507,71 @@
     var decoder = Decoder.ofByteArray().map(bytes -> bytes.length);
     assertEquals(3, decoder.decode(ByteBuffer.wrap(data)));
   }
+
+  @Test
+  void testContramap() {
+    var data = new byte[] {0x0A, 0x0B, 0x0C};
+    var decoder = Decoder.ofByteArray().contramap(bytes -> bytes.slice(1, 1));
+    assertArrayEquals(new byte[] {0x0B}, decoder.decode(ByteBuffer.wrap(data)));
+  }
+
+  @Test
+  void testPredicateTrue() {
+    var data = new byte[] {0x00, 0x01, 0x00};
+    var innerDecoder = Decoder.ofShort().contramap(bytes -> bytes.slice(1, 2).order(bytes.order()));
+    var decoder =
+        Decoder.ofPredicate(
+            byteBuffer -> byteBuffer.get(0) == 0,
+            innerDecoder.withByteOrder(LITTLE_ENDIAN),
+            innerDecoder.withByteOrder(BIG_ENDIAN));
+    assertEquals((short) 1, decoder.decode(ByteBuffer.wrap(data)));
+  }
+
+  @Test
+  void testPredicateFalse() {
+    var data = new byte[] {0x01, 0x01, 0x00};
+    var innerDecoder = Decoder.ofShort().contramap(bytes -> bytes.slice(1, 2).order(bytes.order()));
+    var decoder =
+        Decoder.ofPredicate(
+            byteBuffer -> byteBuffer.get(0) == 0,
+            innerDecoder.withByteOrder(LITTLE_ENDIAN),
+            innerDecoder.withByteOrder(BIG_ENDIAN));
+    assertEquals((short) 256, decoder.decode(ByteBuffer.wrap(data)));
+  }
+
+  @Test
+  void testByteOrder() {
+    var data =
+        new byte[] {
+          0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x04, 0x05, 0x00, 0x00, 0x06, 0x00, 0x07, 0x08,
+          0x00
+        };
+
+    record TestChild(short s1, short s2) {}
+    record TestParent(TestChild tc1, TestChild tc2, TestChild tc3, TestChild tc4) {}
+
+    var decoder =
+        Decoder.ofStructure(
+            TestParent.class,
+            Decoder.ofStructure(TestChild.class, Decoder.ofShort(), Decoder.ofShort())
+                .withByteOrder(LITTLE_ENDIAN),
+            Decoder.ofStructure(TestChild.class, Decoder.ofShort(), Decoder.ofShort())
+                .withByteOrder(BIG_ENDIAN),
+            Decoder.ofStructure(
+                    TestChild.class,
+                    Decoder.ofShort().withByteOrder(LITTLE_ENDIAN),
+                    Decoder.ofShort())
+                .withByteOrder(BIG_ENDIAN),
+            Decoder.ofStructure(
+                    TestChild.class, Decoder.ofShort().withByteOrder(BIG_ENDIAN), Decoder.ofShort())
+                .withByteOrder(LITTLE_ENDIAN));
+
+    assertEquals(
+        new TestParent(
+            new TestChild((short) 1, (short) 2),
+            new TestChild((short) 3, (short) 4),
+            new TestChild((short) 5, (short) 6),
+            new TestChild((short) 7, (short) 8)),
+        decoder.decode(ByteBuffer.wrap(data)));
+  }
 }