From 4cd59ce3dbf31ad8e450b017aa9e36c0dcb73c8e Mon Sep 17 00:00:00 2001 From: Jake Wharton Date: Thu, 5 Dec 2024 15:33:23 -0500 Subject: [PATCH] Add streaming support for Wire request bodies --- CHANGELOG.md | 5 +- gradle/libs.versions.toml | 1 + retrofit-converters/wire/build.gradle | 1 + .../converter/wire/WireConverterFactory.java | 26 +++- .../wire/WireRequestBodyConverter.java | 10 +- .../wire/WireStreamingRequestBody.java | 45 ++++++ .../converter/wire/CrashingPhone.java | 129 ++++++++++++++++++ .../wire/WireConverterFactoryTest.java | 53 ++++++- 8 files changed, 259 insertions(+), 11 deletions(-) create mode 100644 retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java create mode 100644 retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 95cef91334..eb2d6154e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,10 @@ **New** - - Nothing yet! + - First-party converters now support deferring serialization to happen when the request body is written (i.e., during HTTP execution) rather than when the HTTP request is created. In some cases this moves conversion from a calling thread to a background thread, such as in the case when using `Call.enqueue` directly. + + The following converters support this feature through a new `createStreaming()` factory: + - Wire **Changed** diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 94841d99da..16998a3ab0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -83,3 +83,4 @@ robovm = { module = "com.mobidevelop.robovm:robovm-rt", version.ref = "robovm" } googleJavaFormat = "com.google.googlejavaformat:google-java-format:1.25.0" ktlint = "com.pinterest.ktlint:ktlint-cli:1.5.0" compileTesting = "com.google.testing.compile:compile-testing:0.21.0" +testParameterInjector = "com.google.testparameterinjector:test-parameter-injector:1.18" diff --git a/retrofit-converters/wire/build.gradle b/retrofit-converters/wire/build.gradle index 6f7c414a9c..06aef0cc81 100644 --- a/retrofit-converters/wire/build.gradle +++ b/retrofit-converters/wire/build.gradle @@ -12,6 +12,7 @@ dependencies { testImplementation libs.junit testImplementation libs.truth testImplementation libs.okhttp.mockwebserver + testImplementation libs.testParameterInjector } jar { diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java index f3472b5744..aaa4c26cf6 100644 --- a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import okhttp3.RequestBody; import okhttp3.ResponseBody; +import retrofit2.Call; import retrofit2.Converter; import retrofit2.Retrofit; @@ -31,11 +32,30 @@ *

This converter only applies for types which extend from {@link Message}. */ public final class WireConverterFactory extends Converter.Factory { + /** + * Create an instance which serializes request messages to bytes eagerly on the caller thread + * when either {@link Call#execute()} or {@link Call#enqueue} is called. Response bytes are + * always converted to message instances on one of OKHttp's background threads. + */ public static WireConverterFactory create() { - return new WireConverterFactory(); + return new WireConverterFactory(false); } - private WireConverterFactory() {} + /** + * Create an instance which streams serialization of request messages to bytes on the HTTP thread + * This is either the calling thread for {@link Call#execute()}, or one of OKHttp's background + * threads for {@link Call#enqueue}. Response bytes are always converted to message instances on + * one of OKHttp's background threads. + */ + public static WireConverterFactory createStreaming() { + return new WireConverterFactory(true); + } + + private final boolean streaming; + + private WireConverterFactory(boolean streaming) { + this.streaming = streaming; + } @Override public @Nullable Converter responseBodyConverter( @@ -67,6 +87,6 @@ private WireConverterFactory() {} } //noinspection unchecked ProtoAdapter adapter = ProtoAdapter.get((Class) c); - return new WireRequestBodyConverter<>(adapter); + return new WireRequestBodyConverter<>(adapter, streaming); } } diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java index 1d6be146f6..c1d96d3a02 100644 --- a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java @@ -24,16 +24,22 @@ import retrofit2.Converter; final class WireRequestBodyConverter> implements Converter { - private static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf"); + static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf"); private final ProtoAdapter adapter; + private final boolean streaming; - WireRequestBodyConverter(ProtoAdapter adapter) { + WireRequestBodyConverter(ProtoAdapter adapter, boolean streaming) { this.adapter = adapter; + this.streaming = streaming; } @Override public RequestBody convert(T value) throws IOException { + if (streaming) { + return new WireStreamingRequestBody<>(adapter, value); + } + Buffer buffer = new Buffer(); adapter.encode(buffer, value); return RequestBody.create(MEDIA_TYPE, buffer.snapshot()); diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java new file mode 100644 index 0000000000..f441ff4a6d --- /dev/null +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.converter.wire; + +import static retrofit2.converter.wire.WireRequestBodyConverter.MEDIA_TYPE; + +import com.squareup.wire.Message; +import com.squareup.wire.ProtoAdapter; +import java.io.IOException; +import okhttp3.MediaType; +import okhttp3.RequestBody; +import okio.BufferedSink; + +final class WireStreamingRequestBody> extends RequestBody { + private final ProtoAdapter adapter; + private final T value; + + WireStreamingRequestBody(ProtoAdapter adapter, T value) { + this.adapter = adapter; + this.value = value; + } + + @Override + public MediaType contentType() { + return MEDIA_TYPE; + } + + @Override + public void writeTo(BufferedSink sink) throws IOException { + adapter.encode(sink, value); + } +} diff --git a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java new file mode 100644 index 0000000000..f020475e37 --- /dev/null +++ b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java @@ -0,0 +1,129 @@ +// Code generated by Wire protocol buffer compiler, do not edit. +// Source file: phone.proto at 6:1 +package retrofit2.converter.wire; + +import com.squareup.wire.FieldEncoding; +import com.squareup.wire.Message; +import com.squareup.wire.ProtoAdapter; +import com.squareup.wire.ProtoReader; +import com.squareup.wire.ProtoWriter; +import com.squareup.wire.WireField; +import com.squareup.wire.internal.Internal; +import java.io.EOFException; +import java.io.IOException; +import okio.ByteString; + +public final class CrashingPhone extends Message { + public static final ProtoAdapter ADAPTER = new ProtoAdapter_CrashingPhone(); + + private static final long serialVersionUID = 0L; + + public static final String DEFAULT_NUMBER = ""; + + @WireField(tag = 1, adapter = "com.squareup.wire.ProtoAdapter#STRING") + public final String number; + + public CrashingPhone(String number) { + this(number, ByteString.EMPTY); + } + + public CrashingPhone(String number, ByteString unknownFields) { + super(ADAPTER, unknownFields); + this.number = number; + } + + @Override + public Builder newBuilder() { + Builder builder = new Builder(); + builder.number = number; + builder.addUnknownFields(unknownFields()); + return builder; + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if (!(other instanceof CrashingPhone)) return false; + CrashingPhone o = (CrashingPhone) other; + return Internal.equals(unknownFields(), o.unknownFields()) && Internal.equals(number, o.number); + } + + @Override + public int hashCode() { + int result = super.hashCode; + if (result == 0) { + result = unknownFields().hashCode(); + result = result * 37 + (number != null ? number.hashCode() : 0); + super.hashCode = result; + } + return result; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + if (number != null) builder.append(", number=").append(number); + return builder.replace(0, 2, "Phone{").append('}').toString(); + } + + public static final class Builder extends Message.Builder { + public String number; + + public Builder() {} + + public Builder number(String number) { + this.number = number; + return this; + } + + @Override + public CrashingPhone build() { + return new CrashingPhone(number, buildUnknownFields()); + } + } + + private static final class ProtoAdapter_CrashingPhone extends ProtoAdapter { + ProtoAdapter_CrashingPhone() { + super(FieldEncoding.LENGTH_DELIMITED, CrashingPhone.class); + } + + @Override + public int encodedSize(CrashingPhone value) { + return (value.number != null ? ProtoAdapter.STRING.encodedSizeWithTag(1, value.number) : 0) + + value.unknownFields().size(); + } + + @Override + public void encode(ProtoWriter writer, CrashingPhone value) throws IOException { + throw new EOFException("oops!"); + } + + @Override + public CrashingPhone decode(ProtoReader reader) throws IOException { + Builder builder = new Builder(); + long token = reader.beginMessage(); + for (int tag; (tag = reader.nextTag()) != -1; ) { + switch (tag) { + case 1: + builder.number(ProtoAdapter.STRING.decode(reader)); + break; + default: + { + FieldEncoding fieldEncoding = reader.peekFieldEncoding(); + Object value = fieldEncoding.rawProtoAdapter().decode(reader); + builder.addUnknownField(tag, fieldEncoding, value); + } + } + } + reader.endMessage(token); + return builder.build(); + } + + @Override + public CrashingPhone redact(CrashingPhone value) { + Builder builder = value.newBuilder(); + builder.clearUnknownFields(); + return builder.build(); + } + } +} diff --git a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java index ff6c73748f..443a9dee91 100644 --- a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java +++ b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java @@ -17,25 +17,32 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import java.io.EOFException; import java.io.IOException; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import okio.Buffer; import okio.ByteString; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.runner.RunWith; import retrofit2.Call; +import retrofit2.Callback; import retrofit2.Response; import retrofit2.Retrofit; import retrofit2.http.Body; import retrofit2.http.GET; import retrofit2.http.POST; +@RunWith(TestParameterInjector.class) public final class WireConverterFactoryTest { interface Service { @GET("/") @@ -44,6 +51,9 @@ interface Service { @POST("/") Call post(@Body Phone impl); + @POST("/") + Call postCrashing(@Body CrashingPhone impl); + @GET("/") Call wrongClass(); @@ -53,14 +63,17 @@ interface Service { @Rule public final MockWebServer server = new MockWebServer(); - private Service service; + private final Service service; + private final boolean streaming; + + public WireConverterFactoryTest(@TestParameter boolean streaming) { + this.streaming = streaming; - @Before - public void setUp() { Retrofit retrofit = new Retrofit.Builder() .baseUrl(server.url("/")) - .addConverterFactory(WireConverterFactory.create()) + .addConverterFactory( + streaming ? WireConverterFactory.createStreaming() : WireConverterFactory.create()) .build(); service = retrofit.create(Service.class); } @@ -80,6 +93,36 @@ public void serializeAndDeserialize() throws IOException, InterruptedException { assertThat(request.getHeader("Content-Type")).isEqualTo("application/x-protobuf"); } + @Test + public void serializeIsStreamed() throws IOException, InterruptedException { + assumeTrue(streaming); + + Call call = service.postCrashing(new CrashingPhone("(519) 867-5309")); + + final AtomicReference throwableRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + // If streaming were broken, the call to enqueue would throw the exception synchronously. + call.enqueue( + new Callback() { + @Override + public void onResponse(Call call, Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Call call, Throwable t) { + throwableRef.set(t); + latch.countDown(); + } + }); + latch.await(); + + Throwable throwable = throwableRef.get(); + assertThat(throwable).isInstanceOf(EOFException.class); + assertThat(throwable).hasMessageThat().isEqualTo("oops!"); + } + @Test public void deserializeEmpty() throws IOException { server.enqueue(new MockResponse());