(value.length);
+ SizeTPointer lengths = new SizeTPointer(value.length);
- private static native void setAttrStringList(long handle, String name, Object[] value);
+ for (int i = 0; i < value.length; ++i) {
+ valuePointers.put(i, new BytePointer(value[i]));
+ lengths.put(i, value[i].length);
+ }
+ TF_SetAttrStringList(handle, new BytePointer(name), valuePointers, lengths, value.length);
+ }
+ }
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java
index b298f57d9fb..2e10a22f89e 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java
@@ -16,6 +16,7 @@
package org.tensorflow;
import java.util.Objects;
+import org.bytedeco.javacpp.Pointer;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.family.TType;
@@ -105,7 +106,7 @@ public String toString() {
index = idx;
}
- long getUnsafeNativeHandle() {
+ Pointer getUnsafeNativeHandle() {
return operation.getUnsafeNativeHandle(index);
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
index e653373f856..4853d483494 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -15,6 +15,19 @@
package org.tensorflow;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
+
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.PointerPointer;
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TF_Buffer;
+import org.tensorflow.internal.c_api.TF_Graph;
+import org.tensorflow.internal.c_api.TF_Session;
+import org.tensorflow.internal.c_api.TF_SessionOptions;
+import org.tensorflow.internal.c_api.TF_Status;
+
/**
* SavedModelBundle represents a model loaded from storage.
*
@@ -157,14 +170,41 @@ private SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef) {
* Invoked from the native load method. Takes ownership of the handles.
*/
private static SavedModelBundle fromHandle(
- long graphHandle, long sessionHandle, byte[] metaGraphDef) {
+ TF_Graph graphHandle, TF_Session sessionHandle, byte[] metaGraphDef) {
Graph graph = new Graph(graphHandle);
Session session = new Session(graph, sessionHandle);
return new SavedModelBundle(graph, session, metaGraphDef);
}
- private static native SavedModelBundle load(
- String exportDir, String[] tags, byte[] config, byte[] runOptions);
+ private static SavedModelBundle load(
+ String exportDir, String[] tags, byte[] config, byte[] runOptions) {
+ SavedModelBundle bundle = null;
+
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+
+ // allocate parameters for TF_LoadSessionFromSavedModel
+ TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
+ if (config != null && config.length > 0) {
+ TF_SetConfig(opts, new BytePointer(config), config.length, status);
+ status.throwExceptionIfNotOK();
+ }
+ TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
+
+ // load the session
+ TF_Graph graph = TF_NewGraph();
+ TF_Buffer metagraphDef = TF_Buffer.newBuffer();
+ TF_Session session = TF_LoadSessionFromSavedModel(
+ opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags),
+ tags.length, graph, metagraphDef, status);
+ status.throwExceptionIfNotOK();
+
+ // handle the result
+ bundle = fromHandle(graph, session, metagraphDef.get());
+ }
+
+ return bundle;
+ }
static {
TensorFlow.init();
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java
index 6adcdba17b3..9228f93e716 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java
@@ -15,6 +15,17 @@
package org.tensorflow;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteServer;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewServer;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_ServerJoin;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_ServerStart;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_ServerStop;
+
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TF_Server;
+import org.tensorflow.internal.c_api.TF_Status;
+
/**
* An in-process TensorFlow server, for use in distributed training.
*
@@ -83,10 +94,10 @@ public synchronized void stop() {
/** Blocks until the server has been successfully stopped. */
public void join() {
- long handle = 0;
+ TF_Server handle = null;
synchronized (this) {
handle = nativeHandle;
- if (handle != 0) {
+ if (handle != null && !handle.isNull()) {
numJoining++;
}
}
@@ -94,7 +105,7 @@ public void join() {
join(handle);
} finally {
synchronized (this) {
- if (handle != 0) {
+ if (handle != null && !handle.isNull()) {
numJoining--;
}
notifyAll();
@@ -110,20 +121,57 @@ public synchronized void close() throws InterruptedException {
wait();
}
delete(nativeHandle);
- nativeHandle = 0;
+ nativeHandle = null;
}
- private static native long allocate(byte[] serverDef);
+ private static void requireHandle(TF_Server handle) {
+ if (handle == null || handle.isNull()) {
+ throw new IllegalStateException("close() has been called on the Server");
+ }
+ }
+
+ private static TF_Server allocate(byte[] serverDef) {
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_Server server = TF_NewServer(new BytePointer(serverDef), serverDef.length, status);
+ status.throwExceptionIfNotOK();
+ return server;
+ }
+ }
- private static native void start(long nativeHandle);
+ private static void start(TF_Server nativeHandle) {
+ requireHandle(nativeHandle);
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_ServerStart(nativeHandle, status);
+ status.throwExceptionIfNotOK();
+ }
+ }
- private static native void stop(long nativeHandle);
+ private static void stop(TF_Server nativeHandle) {
+ requireHandle(nativeHandle);
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_ServerStop(nativeHandle, status);
+ status.throwExceptionIfNotOK();
+ }
+ }
- private static native void join(long nativeHandle);
+ private static void join(TF_Server nativeHandle) {
+ requireHandle(nativeHandle);
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_ServerJoin(nativeHandle, status);
+ status.throwExceptionIfNotOK();
+ }
+ }
- private static native void delete(long nativeHandle);
+ private static void delete(TF_Server nativeHandle) {
+ requireHandle(nativeHandle);
+ TF_DeleteServer(nativeHandle);
+ }
- private long nativeHandle;
+ private TF_Server nativeHandle;
private int numJoining;
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
index e1b236074e4..0af91f432e0 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
@@ -15,8 +15,27 @@
package org.tensorflow;
+import static org.tensorflow.Graph.resolveOutputs;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
+
import java.util.ArrayList;
import java.util.List;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.PointerPointer;
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TF_Buffer;
+import org.tensorflow.internal.c_api.TF_Graph;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
+import org.tensorflow.internal.c_api.TF_Session;
+import org.tensorflow.internal.c_api.TF_SessionOptions;
+import org.tensorflow.internal.c_api.TF_Status;
+import org.tensorflow.internal.c_api.TF_Tensor;
/**
* Driver for {@link Graph} execution.
@@ -49,7 +68,7 @@ public final class Session implements AutoCloseable {
/** Construct a new session with the associated {@link Graph}. */
public Session(Graph g) {
- this(g, null);
+ this(g, (byte[])null);
}
/**
@@ -75,7 +94,7 @@ public Session(Graph g, byte[] config) {
}
/** Wrap an existing session with the associated {@link Graph}. */
- Session(Graph g, long nativeHandle) {
+ Session(Graph g, TF_Session nativeHandle) {
graph = g;
this.nativeHandle = nativeHandle;
graphRef = g.ref();
@@ -91,7 +110,7 @@ public Session(Graph g, byte[] config) {
public void close() {
graphRef.close();
synchronized (nativeHandleLock) {
- if (nativeHandle == 0) {
+ if (nativeHandle == null || nativeHandle.isNull()) {
return;
}
while (numActiveRuns > 0) {
@@ -104,7 +123,7 @@ public void close() {
}
}
delete(nativeHandle);
- nativeHandle = 0;
+ nativeHandle = null;
}
}
@@ -289,13 +308,13 @@ public Run runAndFetchMetadata() {
}
private Run runHelper(boolean wantMetadata) {
- long[] inputTensorHandles = new long[inputTensors.size()];
- long[] inputOpHandles = new long[inputs.size()];
+ TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()];
+ TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()];
int[] inputOpIndices = new int[inputs.size()];
- long[] outputOpHandles = new long[outputs.size()];
+ TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()];
int[] outputOpIndices = new int[outputs.size()];
- long[] targetOpHandles = new long[targets.size()];
- long[] outputTensorHandles = new long[outputs.size()];
+ TF_Operation[] targetOpHandles = new TF_Operation[targets.size()];
+ TF_Tensor[] outputTensorHandles = new TF_Tensor[outputs.size()];
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
// validity of the Graph and graphRef ensures that.
@@ -305,13 +324,13 @@ private Run runHelper(boolean wantMetadata) {
}
idx = 0;
for (Output> o : inputs) {
- inputOpHandles[idx] = o.getUnsafeNativeHandle();
+ inputOpHandles[idx] = (TF_Operation)o.getUnsafeNativeHandle();
inputOpIndices[idx] = o.index();
idx++;
}
idx = 0;
for (Output> o : outputs) {
- outputOpHandles[idx] = o.getUnsafeNativeHandle();
+ outputOpHandles[idx] = (TF_Operation)o.getUnsafeNativeHandle();
outputOpIndices[idx] = o.index();
idx++;
}
@@ -338,7 +357,7 @@ private Run runHelper(boolean wantMetadata) {
runRef.close();
}
List> outputs = new ArrayList<>();
- for (long h : outputTensorHandles) {
+ for (TF_Tensor h : outputTensorHandles) {
try {
outputs.add(Tensor.fromHandle(h));
} catch (Exception e) {
@@ -358,7 +377,7 @@ private Run runHelper(boolean wantMetadata) {
private class Reference implements AutoCloseable {
public Reference() {
synchronized (nativeHandleLock) {
- if (nativeHandle == 0) {
+ if (nativeHandle == null || nativeHandle.isNull()) {
throw new IllegalStateException("run() cannot be called on the Session after close()");
}
++numActiveRuns;
@@ -368,7 +387,7 @@ public Reference() {
@Override
public void close() {
synchronized (nativeHandleLock) {
- if (nativeHandle == 0) {
+ if (nativeHandle == null || nativeHandle.isNull()) {
return;
}
if (--numActiveRuns == 0) {
@@ -440,15 +459,63 @@ public static final class Run {
private final Graph.Reference graphRef;
private final Object nativeHandleLock = new Object();
- private long nativeHandle;
+ private TF_Session nativeHandle;
private int numActiveRuns;
+ private static void requireHandle(Pointer handle) {
+ if (handle == null || handle.isNull()) {
+ throw new IllegalStateException("close() has been called on the Session");
+ }
+ }
+
+ private static void resolveHandles(String type, Pointer[] src, PointerPointer dst, int n) {
+ if (src.length != n) {
+ throw new IllegalArgumentException("expected " + n + ", got " + src.length + " " + type);
+ }
+ for (int i = 0; i < n; ++i) {
+ if (src[i] == null || src[i].isNull()) {
+ throw new IllegalStateException("invalid " + type + " (#" + i + " of " + n + ")");
+ }
+ dst.put(i, src[i]);
+ }
+ }
+
// TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2().
- private static native long allocate(long graphHandle);
+ private static TF_Session allocate(TF_Graph graphHandle) {
+ return allocate2(graphHandle, null, null);
+ }
+
+ private static TF_Session allocate2(TF_Graph graphHandle, String target, byte[] config) {
+ if (graphHandle == null || graphHandle.isNull()) {
+ throw new IllegalStateException("Graph has been close()d");
+ }
- private static native long allocate2(long graphHandle, String target, byte[] config);
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
+ if (config != null && config.length > 0) {
+ TF_SetConfig(opts, new BytePointer(config), config.length, status);
+ status.throwExceptionIfNotOK();
+ }
+
+ TF_Session session = TF_NewSession(graphHandle, opts, status);
+ status.throwExceptionIfNotOK();
+
+ return session;
+ }
+ }
+
+ private static void delete(TF_Session handle) {
+ requireHandle(handle);
- private static native void delete(long handle);
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_CloseSession(handle, status);
+ // Result of close is ignored, delete anyway.
+ TF_DeleteSession(handle, status);
+ status.throwExceptionIfNotOK();
+ }
+ }
/**
* Execute a session.
@@ -477,15 +544,49 @@ public static final class Run {
* @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
* buffer, false otherwise.
*/
- private static native byte[] run(
- long handle,
+ private static byte[] run(
+ TF_Session handle,
byte[] runOptions,
- long[] inputTensorHandles,
- long[] inputOpHandles,
+ TF_Tensor[] inputTensorHandles,
+ TF_Operation[] inputOpHandles,
int[] inputOpIndices,
- long[] outputOpHandles,
+ TF_Operation[] outputOpHandles,
int[] outputOpIndices,
- long[] targetOpHandles,
+ TF_Operation[] targetOpHandles,
boolean wantRunMetadata,
- long[] outputTensorHandles);
+ TF_Tensor[] outputTensorHandles) {
+ requireHandle(handle);
+
+ int ninputs = inputTensorHandles.length;
+ int noutputs = outputTensorHandles.length;
+ int ntargets = targetOpHandles.length;
+
+ try (PointerScope scope = new PointerScope()) {
+ TF_Output inputs = new TF_Output(ninputs);
+ PointerPointer inputValues = new PointerPointer(ninputs);
+ TF_Output outputs = new TF_Output(noutputs);
+ PointerPointer outputValues = new PointerPointer(noutputs);
+ PointerPointer targets = new PointerPointer(ntargets);
+ TF_Buffer runMetadata = wantRunMetadata ? TF_Buffer.newBuffer() : null;
+
+ resolveHandles("input Tensors", inputTensorHandles, inputValues, ninputs);
+ resolveOutputs("input", inputOpHandles, inputOpIndices, inputs, ninputs);
+ resolveOutputs("output", outputOpHandles, outputOpIndices, outputs, noutputs);
+ resolveHandles("target Operations", targetOpHandles, targets, ntargets);
+
+ TF_Status status = TF_Status.newStatus();
+ TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
+
+ TF_SessionRun(handle, runOpts, inputs, inputValues, ninputs,
+ outputs, outputValues, noutputs, targets, ntargets,
+ runMetadata, status);
+ status.throwExceptionIfNotOK();
+
+ for (int i = 0; i < noutputs; ++i) {
+ outputTensorHandles[i] = outputValues.get(TF_Tensor.class, i);
+ }
+
+ return runMetadata != null ? runMetadata.get() : null;
+ }
+ }
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java
index ced652a77c0..57026923d5d 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java
@@ -15,6 +15,29 @@
package org.tensorflow;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_AllocateTensor;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_BOOL;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DOUBLE;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteTensor;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_FLOAT;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetCode;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_INT32;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_INT64;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_INTERNAL;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_OK;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_OUT_OF_RANGE;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_STRING;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetStatus;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_StringDecode;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_StringEncode;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_StringEncodedSize;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorData;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_UINT8;
+
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.BufferOverflowException;
@@ -27,8 +50,18 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.function.Consumer;
+import org.bytedeco.javacpp.BooleanPointer;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.DoublePointer;
+import org.bytedeco.javacpp.FloatPointer;
+import org.bytedeco.javacpp.IntPointer;
+import org.bytedeco.javacpp.Loader;
+import org.bytedeco.javacpp.LongPointer;
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.PointerScope;
+import org.bytedeco.javacpp.SizeTPointer;
+import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
-import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat64;
@@ -126,7 +159,7 @@ public static Tensor create(Object obj, DataType dtype)
long[] dimSizes = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, dimSizes);
Tensor t = new Tensor(dtype, Shape.make(dimSizes));
- long nativeHandle;
+ TF_Tensor nativeHandle;
if (t.dtype != TString.DTYPE) {
long byteSize = elemByteSize(t.dtype) * t.shape.size();
nativeHandle = allocate(t.dtype.nativeCode(), dimSizes, byteSize);
@@ -263,7 +296,7 @@ public static Tensor allocate(DataType dtype, Shape shap
public static Tensor allocate(DataType dtype, Shape shape, long size) {
Tensor t = new Tensor<>(dtype, shape);
- long nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size);
+ TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size);
t.nativeRef = new NativeReference(nativeHandle);
return t;
}
@@ -318,7 +351,7 @@ private static Tensor allocateForBuffer(DataType dataTyp
nbytes = nBuffered;
}
Tensor t = new Tensor<>(dataType, Shape.make(dimSizes));
- long nativeHandle = allocate(t.dtype.nativeCode(), dimSizes, nbytes);
+ TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), dimSizes, nbytes);
t.nativeRef = new NativeReference(nativeHandle);
return t;
}
@@ -352,7 +385,7 @@ public T data() {
/** Returns the size, in bytes, of the tensor data. */
public long numBytes() {
if (numBytes == null) {
- numBytes = tensorflow.TF_TensorByteSize(nativeRef.cTensor);
+ numBytes = TF_TensorByteSize(nativeRef.tensorHandle);
}
return numBytes;
}
@@ -548,7 +581,7 @@ public String toString() {
*
* Takes ownership of the handle.
*/
- static Tensor> fromHandle(long handle) {
+ static Tensor> fromHandle(TF_Tensor handle) {
Tensor> t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.make(shape(handle)));
t.nativeRef = new NativeReference(handle);
return t;
@@ -559,20 +592,16 @@ static Tensor> fromHandle(long handle) {
*
*
Takes ownership of the handle.
*/
- static Tensor> fromHandle(long handle, EagerSession session) {
+ static Tensor> fromHandle(TF_Tensor handle, EagerSession session) {
Tensor> t = fromHandle(handle);
t.nativeRef.eager(session, t);
return t;
}
- long getNativeHandle() {
+ TF_Tensor getNativeHandle() {
return nativeRef.tensorHandle;
}
- TF_Tensor getNative() {
- return nativeRef.cTensor;
- }
-
private NativeReference nativeRef = null;
private final DataType dtype;
private final Shape shape;
@@ -649,7 +678,7 @@ void delete() {
}
}
- NativeReference(long tensorHandle) {
+ NativeReference(TF_Tensor tensorHandle) {
setTensorHandle(tensorHandle);
}
@@ -661,24 +690,22 @@ void eager(EagerSession session, Tensor> tensor) {
}
synchronized void release() {
- if (tensorHandle != 0L) {
+ if (tensorHandle != null && !tensorHandle.isNull()) {
// Clear any remaining eager reference to this tensor
if (eagerRef != null) {
eagerRef.clear();
eagerRef = null;
}
Tensor.delete(tensorHandle);
- setTensorHandle(0L);
+ setTensorHandle(null);
}
}
- private long tensorHandle;
- private final TF_Tensor cTensor = new TF_Tensor();
+ private TF_Tensor tensorHandle;
private EagerReference eagerRef;
- private void setTensorHandle(long tensorHandle) {
+ private void setTensorHandle(TF_Tensor tensorHandle) {
this.tensorHandle = tensorHandle;
- cTensor.temporaryHackToSetAddressFromHandle(tensorHandle);
}
}
@@ -821,35 +848,503 @@ private void throwExceptionIfTypeIsIncompatible(Object o) {
}
}
- private static native long allocate(int dtype, long[] shape, long byteSize);
+ private static void requireHandle(TF_Tensor handle) {
+ if (handle == null || handle.isNull()) {
+ throw new IllegalStateException("close() was called on the Tensor");
+ }
+ }
+
+ private static int elemByteSize(int dtype) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (dtype) {
+ case TF_BOOL:
+ case TF_UINT8:
+ assert Loader.sizeof(BooleanPointer.class) == 1 :
+ "Java boolean not compatible with TF_BOOL";
+ assert Loader.sizeof(BytePointer.class) == 1 :
+ "Java byte not compatible with TF_UINT8";
+ return 1;
+ case TF_FLOAT:
+ case TF_INT32:
+ assert Loader.sizeof(FloatPointer.class) == 4 :
+ "Java float not compatible with TF_FLOAT";
+ assert Loader.sizeof(IntPointer.class) == 4 :
+ "Java int not compatible with TF_INT32";
+ return 4;
+ case TF_DOUBLE:
+ case TF_INT64:
+ assert Loader.sizeof(DoublePointer.class) == 8 :
+ "Java double not compatible with TF_DOUBLE";
+ assert Loader.sizeof(LongPointer.class) == 8 :
+ "Java long not compatible with TF_INT64";
+ return 8;
+ default:
+ return 0;
+ }
+ }
+
+ /** Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor. */
+ private static void writeScalar(Object src, int dtype, BytePointer dst, long dstSize) {
+ int sz = elemByteSize(dtype);
+ if (sz != dstSize) {
+ throw new IllegalStateException("scalar (" + sz
+ + " bytes) not compatible with allocated tensor (" + dstSize + " bytes)");
+ }
+ switch (dtype) {
+ case TF_FLOAT:
+ dst.putFloat((Float)src);
+ break;
+ case TF_DOUBLE:
+ dst.putDouble((Double)src);
+ break;
+ case TF_INT32:
+ dst.putInt((Integer)src);
+ break;
+ case TF_INT64:
+ dst.putLong((Long)src);
+ break;
+ case TF_UINT8:
+ dst.put((Byte)src);
+ break;
+ case TF_BOOL:
+ dst.putBool((Boolean)src);
+ break;
+ default:
+ throw new IllegalStateException("invalid DataType(" + dtype + ")");
+ }
+ }
+
+ private static int getArrayLength(Object array, int dtype) {
+ switch (dtype) {
+ case TF_FLOAT: return ((float[])array).length;
+ case TF_DOUBLE: return ((double[])array).length;
+ case TF_INT32: return ((int[])array).length;
+ case TF_INT64: return ((long[])array).length;
+ case TF_UINT8: return ((byte[])array).length;
+ case TF_BOOL: return ((boolean[])array).length;
+ default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
+ }
+ }
+
+ /** Copy a 1-D array of Java primitive types to the tensor buffer dst.
+ * Returns the number of bytes written to dst. */
+ private static long write1DArray(Object array, int dtype, BytePointer dst, long dstSize) {
+ int nelems = getArrayLength(array, dtype);
+ long toCopy = nelems * elemByteSize(dtype);
+ if (toCopy > dstSize) {
+ throw new IllegalStateException(
+ "cannot write Java array of " + toCopy + " bytes to Tensor of " + dstSize + " bytes");
+ }
+ switch (dtype) {
+ case TF_FLOAT:
+ new FloatPointer(dst).position(dst.position() / 4).put((float[])array);
+ break;
+ case TF_DOUBLE:
+ new DoublePointer(dst).position(dst.position() / 8).put((double[])array);
+ break;
+ case TF_INT32:
+ new IntPointer(dst).position(dst.position() / 4).put((int[])array);
+ break;
+ case TF_INT64:
+ new LongPointer(dst).position(dst.position() / 8).put((long[])array);
+ break;
+ case TF_UINT8:
+ dst.put((byte[])array);
+ break;
+ case TF_BOOL:
+ new BooleanPointer(dst).position(dst.position()).put((boolean[])array);
+ break;
+ default:
+ throw new IllegalStateException("invalid DataType(" + dtype + ")");
+ }
+ return toCopy;
+ }
+
+ /** Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
+ * Java primitive types. Returns the number of bytes read from src. */
+ private static long read1DArray(int dtype, BytePointer src, long srcSize, Object dst) {
+ int len = getArrayLength(dst, dtype);
+ long sz = len * elemByteSize(dtype);
+ if (sz > srcSize) {
+ throw new IllegalStateException(
+ "cannot fill a Java array of " + sz + "bytes with a Tensor of " + srcSize + " bytes");
+ }
+ switch (dtype) {
+ case TF_FLOAT:
+ new FloatPointer(src).position(src.position() / 4).get((float[])dst);
+ break;
+ case TF_DOUBLE:
+ new DoublePointer(src).position(src.position() / 8).get((double[])dst);
+ break;
+ case TF_INT32:
+ new IntPointer(src).position(src.position() / 4).get((int[])dst);
+ break;
+ case TF_INT64:
+ new LongPointer(src).position(src.position() / 8).get((long[])dst);
+ break;
+ case TF_UINT8:
+ src.get((byte[])dst);
+ break;
+ case TF_BOOL:
+ new BooleanPointer(src).position(src.position()).get((boolean[])dst);
+ break;
+ default:
+ throw new IllegalStateException("invalid DataType(" + dtype + ")");
+ }
+ return sz;
+ }
+
+ private static long writeNDArray(Object src, int dtype, int dimsLeft,
+ BytePointer dst, long dstSize) {
+ if (dimsLeft == 1) {
+ return write1DArray(src, dtype, dst, dstSize);
+ }
+ Object[] ndarray = (Object[])src;
+ long sz = 0;
+ for (int i = 0; i < ndarray.length; ++i) {
+ Object row = ndarray[i];
+ sz += writeNDArray(row, dtype, dimsLeft - 1,
+ new BytePointer(dst).position(dst.position() + sz), dstSize - sz);
+ }
+ return sz;
+ }
+
+ private static long readNDArray(int dtype, BytePointer src, long srcSize,
+ int dimsLeft, Object dst) {
+ if (dimsLeft == 1) {
+ return read1DArray(dtype, src, srcSize, dst);
+ }
+ Object[] ndarray = (Object[])dst;
+ long sz = 0;
+ for (int i = 0; i < ndarray.length; ++i) {
+ Object row = ndarray[i];
+ sz += readNDArray(dtype, new BytePointer(src).position(src.position() + sz),
+ srcSize - sz, dimsLeft - 1, row);
+ }
+ return sz;
+ }
+
+ private static byte[] TF_StringDecodeToArray(BytePointer src, long srcLen, TF_Status status) {
+ try (PointerScope scope = new PointerScope()) {
+ BytePointer dst = new BytePointer((Pointer)null);
+ SizeTPointer dstLen = new SizeTPointer(1);
+ TF_StringDecode(src, srcLen, dst, dstLen, status);
+ if (TF_GetCode(status) != TF_OK) {
+ return null;
+ }
+ byte[] ret = new byte[(int)dstLen.get()];
+ dst.get(ret);
+ return ret;
+ }
+ }
+
+ private static class StringTensorWriter {
+ StringTensorWriter(TF_Tensor t, long numElements) {
+ offset = 0;
+ poffsets = new BytePointer(TF_TensorData(t));
+ pdata = new BytePointer(poffsets).position(8 * numElements);
+ plimit = new BytePointer(poffsets).position(TF_TensorByteSize(t));
+ }
+
+ void Add(BytePointer src, long len, TF_Status status) {
+ if (TF_GetCode(status) != TF_OK) return;
+ if (plimit.position() - poffsets.position() < 8) {
+ TF_SetStatus(status, TF_OUT_OF_RANGE,
+ "TF_STRING tensor encoding ran out of space for offsets, "
+ + "this is likely a bug, please file an issue at "
+ + "https://github.com/tensorflow/java/issues/new");
+ return;
+ }
+ poffsets.putLong(offset);
+ long written =
+ TF_StringEncode(src, len, pdata, plimit.position() - pdata.position(), status);
+ offset += written;
+ poffsets.position(poffsets.position() + 8);
+ pdata.position(pdata.position() + written);
+ }
+
+ long offset;
+ BytePointer poffsets;
+ BytePointer pdata;
+ BytePointer plimit;
+ }
+
+ private static class StringTensorReader {
+ StringTensorReader(TF_Tensor t, long numElements) {
+ index = 0;
+ offsets = new BytePointer(TF_TensorData(t));
+ data = new BytePointer(offsets).position(8 * numElements);
+ limit = new BytePointer(offsets).position(TF_TensorByteSize(t));
+ }
+
+ byte[] Next(TF_Status status) {
+ if (TF_GetCode(status) != TF_OK) return null;
+ long offset = 0;
+ BytePointer poffset = new BytePointer(offsets).position(8 * index);
+ if (poffset.position() >= limit.position()) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Invalid TF_STRING tensor, offsets table seems to be too small");
+ return null;
+ }
+ offset = poffset.getLong();
+ BytePointer pdata = new BytePointer(data).position(data.position() + offset);
+ if (pdata.position() >= limit.position()) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Invalid TF_STRING tensor, invalid entry in offset table");
+ return null;
+ }
+ ++index;
+ return TF_StringDecodeToArray(pdata, limit.position() - pdata.position(), status);
+ }
+
+ int index;
+ BytePointer offsets;
+ BytePointer data;
+ BytePointer limit;
+ }
+
+ private static void readNDStringArray(StringTensorReader reader, int dimsLeft,
+ Object[] dst, TF_Status status) {
+ if (dimsLeft == 1) {
+ for (int i = 0; i < dst.length; ++i) {
+ byte[] elem = reader.Next(status);
+ if (TF_GetCode(status) != TF_OK) return;
+ dst[i] = elem;
+ }
+ return;
+ }
+ for (int i = 0; i < dst.length; ++i) {
+ readNDStringArray(reader, dimsLeft - 1, (Object[])dst[i], status);
+ if (TF_GetCode(status) != TF_OK) return;
+ }
+ }
+
+ private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) {
+ TF_Tensor t = TF_AllocateTensor(dtype, shape, shape.length, byteSize);
+ if (t == null || t.isNull()) {
+ throw new IllegalStateException("unable to allocate memory for the Tensor");
+ }
+ return t;
+ }
- private static native long allocateScalarBytes(byte[] value);
+ private static TF_Tensor allocateScalarBytes(byte[] value) {
+ // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
+ // TF_StringEncode-encoded bytes.
+ long dstLen = TF_StringEncodedSize(value.length);
+ TF_Tensor t = TF_AllocateTensor(TF_STRING, (long[])null, 0, 8 + dstLen);
+ BytePointer dst = new BytePointer(TF_TensorData(t));
+ dst.putLong(0); // The offset table
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_StringEncode(new BytePointer(value), value.length, dst.position(8), dstLen, status);
+ status.throwExceptionIfNotOK();
+ return t;
+ }
+ }
- private static native long allocateNonScalarBytes(long[] shape, Object[] value);
+ private static long nonScalarStringTensorSize(Object value, int numDims) {
+ if (numDims == 0) {
+ // This is the last dimension, i.e., value should correspond to a jbyteArray
+ // encoding the string.
+ return TF_StringEncodedSize(((byte[])value).length);
+ }
+ Object[] array = (Object[])value;
+ long ret = 0;
+ for (int i = 0; i < array.length; ++i) {
+ Object elem = array[i];
+ if (elem == null) {
+ throw new IllegalStateException("null entries in provided array");
+ }
+ ret += nonScalarStringTensorSize(elem, numDims - 1);
+ }
+ return ret;
+ }
- private static native void delete(long handle);
+ private static void fillNonScalarStringTensorData(Object value, int numDims,
+ StringTensorWriter writer, TF_Status status) {
+ if (numDims == 0) {
+ byte[] src = (byte[])value;
+ writer.Add(new BytePointer(src), src.length, status);
+ return;
+ }
+ Object[] array = (Object[])value;
+ for (int i = 0; i < array.length; ++i) {
+ Object elem = array[i];
+ if (elem == null) {
+ throw new IllegalStateException("null entries in provided array");
+ }
+ fillNonScalarStringTensorData(elem, numDims - 1, writer, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ }
+ }
- private static native ByteBuffer buffer(long handle);
+ private static TF_Tensor allocateNonScalarBytes(long[] shape, Object[] value) {
+ // TF_STRING tensors are encoded with a table of 8-byte offsets following by
+ // TF_StringEncode-encoded bytes.
+ int numDims = shape.length;
+ long numElements = 1;
+ for (int i = 0; i < numDims; ++i) {
+ numElements *= shape[i];
+ }
+ long encodedSize = nonScalarStringTensorSize(value, numDims);
+ TF_Tensor t = TF_AllocateTensor(TF_STRING, shape, numDims,
+ 8 * numElements + encodedSize);
+ if (t == null || t.isNull()) {
+ throw new IllegalStateException("unable to allocate memory for the Tensor");
+ }
+ TF_Status status = TF_Status.newStatus();
+ try (PointerScope scope = new PointerScope()) {
+ StringTensorWriter writer = new StringTensorWriter(t, numElements);
+ fillNonScalarStringTensorData(value, numDims, writer, status);
+ status.throwExceptionIfNotOK();
+ return t;
+ }
+ }
- private static native int dtype(long handle);
+ private static void delete(TF_Tensor handle) {
+ if (handle == null || handle.isNull()) return;
+ TF_DeleteTensor(handle);
+ }
- private static native long[] shape(long handle);
+ private static ByteBuffer buffer(TF_Tensor handle) {
+ requireHandle(handle);
+ return TF_TensorData(handle).capacity(TF_TensorByteSize(handle)).asByteBuffer();
+ }
- private static native void setValue(long handle, Object value);
+ private static int dtype(TF_Tensor handle) {
+ requireHandle(handle);
+ return TF_TensorType(handle);
+ }
- private static native float scalarFloat(long handle);
+ private static long[] shape(TF_Tensor handle) {
+ requireHandle(handle);
+ int numDims = TF_NumDims(handle);
+ long[] dims = new long[numDims];
+ for (int i = 0; i < numDims; ++i) {
+ dims[i] = TF_Dim(handle, i);
+ }
+ return dims;
+ }
- private static native double scalarDouble(long handle);
+ private static void setValue(TF_Tensor handle, Object value) {
+ requireHandle(handle);
+ int numDims = TF_NumDims(handle);
+ int dtype = TF_TensorType(handle);
+ BytePointer data = new BytePointer(TF_TensorData(handle));
+ long sz = TF_TensorByteSize(handle);
+ if (numDims == 0) {
+ writeScalar(value, dtype, data, sz);
+ } else {
+ writeNDArray(value, dtype, numDims, data, sz);
+ }
+ }
- private static native int scalarInt(long handle);
+ private static float scalarFloat(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_FLOAT) {
+ throw new IllegalStateException("Tensor is not a float scalar");
+ }
+ return new FloatPointer(TF_TensorData(handle)).get();
+ }
+
+ private static double scalarDouble(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_DOUBLE) {
+ throw new IllegalStateException("Tensor is not a double scalar");
+ }
+ return new DoublePointer(TF_TensorData(handle)).get();
+ }
+
+ private static int scalarInt(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_INT32) {
+ throw new IllegalStateException("Tensor is not a int scalar");
+ }
+ return new IntPointer(TF_TensorData(handle)).get();
+ }
- private static native long scalarLong(long handle);
+ private static long scalarLong(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_INT64) {
+ throw new IllegalStateException("Tensor is not a long scalar");
+ }
+ return new LongPointer(TF_TensorData(handle)).get();
+ }
- private static native boolean scalarBoolean(long handle);
+ private static boolean scalarBoolean(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_BOOL) {
+ throw new IllegalStateException("Tensor is not a boolean scalar");
+ }
+ return new BooleanPointer(TF_TensorData(handle)).get();
+ }
- private static native byte[] scalarBytes(long handle);
+ private static byte[] scalarBytes(TF_Tensor handle) {
+ requireHandle(handle);
+ if (TF_NumDims(handle) != 0) {
+ throw new IllegalStateException("Tensor is not a scalar");
+ }
+ if (TF_TensorType(handle) != TF_STRING) {
+ throw new IllegalArgumentException("Tensor is not a string/bytes scalar");
+ }
+ BytePointer data = new BytePointer(TF_TensorData(handle));
+ BytePointer src = new BytePointer(data).position(8);
+ long srcLen = TF_TensorByteSize(handle) - 8;
+ long offset = data.getLong();
+ if (offset >= srcLen) {
+ throw new IllegalArgumentException("invalid tensor encoding: bad offsets");
+ }
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ byte[] ret = TF_StringDecodeToArray(src, srcLen, status);
+ status.throwExceptionIfNotOK();
+ return ret;
+ }
+ }
- private static native void readNDArray(long handle, Object value);
+ private static void readNDArray(TF_Tensor handle, Object value) {
+ requireHandle(handle);
+ int numDims = TF_NumDims(handle);
+ int dtype = TF_TensorType(handle);
+ Pointer data = TF_TensorData(handle);
+ long sz = TF_TensorByteSize(handle);
+ if (numDims == 0) {
+ throw new IllegalArgumentException(
+ "copyTo() is not meant for scalar Tensors, use the scalar "
+ + "accessor (floatValue(), intValue() etc.) instead");
+ }
+ if (dtype == TF_STRING) {
+ long numElements = 1;
+ for (int i = 0; i < numDims; ++i) {
+ numElements *= TF_Dim(handle, i);
+ }
+ try (PointerScope scope = new PointerScope()) {
+ StringTensorReader reader = new StringTensorReader(handle, numElements);
+ TF_Status status = TF_Status.newStatus();
+ readNDStringArray(reader, numDims, (Object[])value, status);
+ status.throwExceptionIfNotOK();
+ return;
+ }
+ }
+ readNDArray(dtype, new BytePointer(data), sz, numDims, value);
+ }
static {
TensorFlow.init();
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
index a9debb0063d..5abe9f1ffd5 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
@@ -15,10 +15,24 @@
package org.tensorflow;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteLibraryHandle;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetAllOpList;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetOpList;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadLibrary;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version;
+
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TF_Buffer;
+import org.tensorflow.internal.c_api.TF_Library;
+import org.tensorflow.internal.c_api.TF_Status;
+
/** Static utility methods describing the TensorFlow runtime. */
public final class TensorFlow {
/** Returns the version of the underlying TensorFlow runtime. */
- public static native String version();
+ public static String version() {
+ return TF_Version().getString();
+ }
/**
* All the TensorFlow operations available in this address space.
@@ -27,7 +41,12 @@ public final class TensorFlow {
* href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList
* protocol buffer, which lists all the available TensorFlow operations.
*/
- public static native byte[] registeredOpList();
+ public static byte[] registeredOpList() {
+ TF_Buffer buf = TF_GetAllOpList();
+ byte[] ret = buf.get();
+ TF_DeleteBuffer(buf);
+ return ret;
+ }
/**
* Load the dynamic library in filename and register the operations and kernels present in that
@@ -40,7 +59,7 @@ public final class TensorFlow {
* @throws UnsatisfiedLinkError if filename cannot be loaded.
*/
public static byte[] loadLibrary(String filename) {
- long h = 0;
+ TF_Library h = null;
try {
h = libraryLoad(filename);
} catch (RuntimeException e) {
@@ -53,11 +72,25 @@ public static byte[] loadLibrary(String filename) {
}
}
- private static native long libraryLoad(String filename);
+ private static TF_Library libraryLoad(String filename) {
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_Library h = TF_LoadLibrary(filename, status);
+ status.throwExceptionIfNotOK();
+ return h;
+ }
+ }
- private static native void libraryDelete(long handle);
+ private static void libraryDelete(TF_Library handle) {
+ if (handle != null && !handle.isNull()) {
+ TF_DeleteLibraryHandle(handle);
+ }
+ }
- private static native byte[] libraryOpList(long handle);
+ private static byte[] libraryOpList(TF_Library handle) {
+ TF_Buffer buf = TF_GetOpList(handle);
+ return buf.get();
+ }
private TensorFlow() {}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlowException.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlowException.java
index 7ff740dfeaa..7d2c943ca24 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlowException.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlowException.java
@@ -17,10 +17,10 @@
/** Unchecked exception thrown when executing TensorFlow Graphs. */
public final class TensorFlowException extends RuntimeException {
- TensorFlowException(String message, Throwable cause) {
- super(message, cause);
- }
- TensorFlowException(String message) {
+ public TensorFlowException(String message) {
super(message);
}
+ public TensorFlowException(String message, Throwable cause) {
+ super(message, cause);
+ }
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java
new file mode 100644
index 00000000000..e0bbff2a32f
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java
@@ -0,0 +1,57 @@
+/*
+ Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
+
+package org.tensorflow.internal.c_api;
+
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext;
+
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.annotation.Properties;
+
+@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
+public abstract class AbstractTFE_Context extends Pointer {
+ protected static class DeleteDeallocator extends TFE_Context implements Pointer.Deallocator {
+ DeleteDeallocator(TFE_Context s) { super(s); }
+ @Override public void deallocate() { if(!isNull()) TFE_DeleteContext(this); setNull(); }
+ }
+
+ /** References to prevent deallocation. */
+ protected TFE_ContextOptions opts;
+
+ public AbstractTFE_Context(Pointer p) { super(p); }
+
+ /**
+ * Calls TFE_NewContext(), and registers a deallocator.
+ * @return TFE_Context created. Do not call TFE_DeleteContext() on it.
+ */
+ public static TFE_Context newSession(TFE_ContextOptions opts, TF_Status status) {
+ TFE_Context c = TFE_NewContext(opts, status);
+ if (c != null) {
+ c.opts = opts;
+ c.deallocator(new DeleteDeallocator(c));
+ }
+ return c;
+ }
+
+ /**
+ * Calls the deallocator, if registered, otherwise has no effect.
+ */
+ public void delete() {
+ deallocate();
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java
new file mode 100644
index 00000000000..cd9ea29b946
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java
@@ -0,0 +1,54 @@
+/*
+ Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
+
+package org.tensorflow.internal.c_api;
+
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContextOptions;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContextOptions;
+
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.annotation.Properties;
+
+@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
+public abstract class AbstractTFE_ContextOptions extends Pointer {
+ protected static class DeleteDeallocator extends
+ TFE_ContextOptions implements Pointer.Deallocator {
+ DeleteDeallocator(TFE_ContextOptions s) { super(s); }
+ @Override public void deallocate() { if (!isNull()) TFE_DeleteContextOptions(this); setNull(); }
+ }
+
+ public AbstractTFE_ContextOptions(Pointer p) { super(p); }
+
+ /**
+ * Calls TFE_NewContextOptions(), and registers a deallocator.
+ * @return TFE_ContextOptions created. Do not call TFE_DeleteContextOptions() on it.
+ */
+ public static TFE_ContextOptions newContextOptions() {
+ TFE_ContextOptions o = TFE_NewContextOptions();
+ if (o != null) {
+ o.deallocator(new DeleteDeallocator(o));
+ }
+ return o;
+ }
+
+ /**
+ * Calls the deallocator, if registered, otherwise has no effect.
+ */
+ public void delete() {
+ deallocate();
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java
index b750c53ca53..e776cec5f41 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java
@@ -46,16 +46,22 @@ public static TF_Buffer newBuffer() {
return b;
}
- /** Returns {@code newBufferFromString(new BytePointer(proto)). */
+ /** Returns {@code newBufferFromString(new BytePointer(proto)), or null if proto is null or empty. */
public static TF_Buffer newBufferFromString(byte[] proto) {
+ if (proto == null || proto.length == 0) {
+ return null;
+ }
return newBufferFromString(new BytePointer(proto));
}
/**
* Calls TF_NewBufferFromString(), and registers a deallocator.
- * @return TF_Buffer created. Do not call TF_DeleteBuffer() on it.
+ * @return TF_Buffer created, or null if proto is null or empty. Do not call TF_DeleteBuffer() on it.
*/
public static TF_Buffer newBufferFromString(Pointer proto) {
+ if (proto == null || proto.isNull() || proto.limit() == 0) {
+ return null;
+ }
TF_Buffer b = TF_NewBufferFromString(proto, proto.limit());
if (b != null) {
b.deallocator(new DeleteDeallocator(b));
@@ -63,6 +69,19 @@ public static TF_Buffer newBufferFromString(Pointer proto) {
return b;
}
+ /**
+ * Returns a copy of the data in a Java array, or throws IndexOutOfBoundsException if too large.
+ */
+ public byte[] get() {
+ long length = ((TF_Buffer)this).length();
+ if (length > Integer.MAX_VALUE) {
+ throw new IndexOutOfBoundsException("TF_Buffer is too large to serialize into a byte[] array");
+ }
+ byte[] data = new byte[(int)length];
+ new BytePointer(((TF_Buffer)this).data()).get(data);
+ return data;
+ }
+
/**
* Calls the deallocator, if registered, otherwise has no effect.
*/
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java
index 776311252d6..126acc1afbf 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java
@@ -17,23 +17,36 @@
package org.tensorflow.internal.c_api;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
+import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.annotation.Properties;
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public abstract class AbstractTF_Session extends Pointer {
protected static class DeleteDeallocator extends TF_Session implements Pointer.Deallocator {
DeleteDeallocator(TF_Session s) { super(s); }
- @Override public void deallocate() { if (!isNull()) TF_DeleteSession(this, TF_Status
- .newStatus()); setNull(); }
+ @Override public void deallocate() {
+ if (!isNull()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_CloseSession(this, status);
+ // Result of close is ignored, delete anyway.
+ TF_DeleteSession(this, status);
+ setNull();
+ }
+ }
}
/** References to prevent deallocation. */
protected TF_Graph graph;
protected TF_SessionOptions opts;
+ protected TF_Buffer run_options;
+ protected TF_Buffer meta_graph_def;
protected TF_Status status;
public AbstractTF_Session(Pointer p) { super(p); }
@@ -53,6 +66,25 @@ public static TF_Session newSession(TF_Graph graph, TF_SessionOptions opts, TF_S
return s;
}
+ /**
+ * Calls TF_LoadSessionFromSavedModel(), and registers a deallocator.
+ * @return TF_Session created. Do not call TF_DeleteSession() on it.
+ */
+ public static TF_Session loadSessionFromSavedModel(TF_SessionOptions session_options, TF_Buffer run_options,
+ String export_dir, String[] tags, TF_Graph graph, TF_Buffer meta_graph_def, TF_Status status) {
+ TF_Session s = TF_LoadSessionFromSavedModel(session_options, run_options,
+ new BytePointer(export_dir), new PointerPointer(tags), tags.length, graph, meta_graph_def, status);
+ if (s != null) {
+ s.graph = graph;
+ s.opts = session_options;
+ s.run_options = run_options;
+ s.meta_graph_def = meta_graph_def;
+ s.status = status;
+ s.deallocator(new DeleteDeallocator(s));
+ }
+ return s;
+ }
+
/**
* Calls the deallocator, if registered, otherwise has no effect.
*/
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java
index ccfb7a7d84c..28895708e72 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java
@@ -18,10 +18,21 @@
package org.tensorflow.internal.c_api;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteStatus;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_FAILED_PRECONDITION;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetCode;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_INVALID_ARGUMENT;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_Message;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewStatus;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_OK;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_OUT_OF_RANGE;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_PERMISSION_DENIED;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_RESOURCE_EXHAUSTED;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_UNAUTHENTICATED;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_UNIMPLEMENTED;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.annotation.Properties;
+import org.tensorflow.TensorFlowException;
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public abstract class AbstractTF_Status extends Pointer {
@@ -50,4 +61,27 @@ public static TF_Status newStatus() {
public void delete() {
deallocate();
}
+
+ /** Map TF_Code to unchecked exception, and throw if not TF_OK. */
+ public void throwExceptionIfNotOK() {
+ TF_Status s = (TF_Status)this;
+ switch (TF_GetCode(s)) {
+ case TF_OK:
+ break;
+ case TF_INVALID_ARGUMENT:
+ throw new IllegalArgumentException(TF_Message(s).getString());
+ case TF_UNAUTHENTICATED:
+ case TF_PERMISSION_DENIED:
+ throw new SecurityException(TF_Message(s).getString());
+ case TF_RESOURCE_EXHAUSTED:
+ case TF_FAILED_PRECONDITION:
+ throw new IllegalStateException(TF_Message(s).getString());
+ case TF_OUT_OF_RANGE:
+ throw new IndexOutOfBoundsException(TF_Message(s).getString());
+ case TF_UNIMPLEMENTED:
+ throw new UnsupportedOperationException(TF_Message(s).getString());
+ default:
+ throw new TensorFlowException(TF_Message(s).getString());
+ }
+ }
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java
index 6a0fae0f431..a46af633112 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java
@@ -22,7 +22,6 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewTensor;
import org.bytedeco.javacpp.Pointer;
-import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.annotation.Properties;
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
@@ -42,11 +41,6 @@ protected static class DeleteDeallocator extends TF_Tensor implements Pointer.De
public AbstractTF_Tensor(Pointer p) { super(p); }
- // WARNING: This is a temporary hack to create a `TF_Tensor` object out of the Tensor native handle
- public void temporaryHackToSetAddressFromHandle(long tensorNativeHandle) {
- this.address = tensorNativeHandle;
- }
-
/**
* Calls TF_NewTensor(), and registers a deallocator.
* @return TF_Tensor created. Do not call TF_DeleteTensor() on it.
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
index 0c2ca424022..27b6f17b467 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
@@ -21,6 +21,7 @@
import org.bytedeco.javacpp.ClassProperties;
import org.bytedeco.javacpp.LoadEnabled;
import org.bytedeco.javacpp.Loader;
+import org.bytedeco.javacpp.annotation.NoException;
import org.bytedeco.javacpp.annotation.Platform;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.tools.Info;
@@ -91,6 +92,7 @@
},
target = "org.tensorflow.internal.c_api",
global = "org.tensorflow.internal.c_api.global.tensorflow")
+@NoException
public class tensorflow implements LoadEnabled, InfoMapper {
@Override public void init(ClassProperties properties) {
@@ -181,12 +183,13 @@ public void map(InfoMap infoMap) {
.put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status"))
.put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer"))
.put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor"))
+ .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session"))
.put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions").base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions"))
.put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph"))
.put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();"))
.put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();"))
.put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions").base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions"))
- .put(new Info("TF_Operation", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell",
+ .put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell",
"TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2",
"TFE_MonitoringIntGaugeCell", "TFE_MonitoringStringGaugeCell", "TFE_MonitoringBoolGaugeCell",
"TFE_MonitoringIntGauge0", "TFE_MonitoringIntGauge1", "TFE_MonitoringIntGauge2",
@@ -199,10 +202,10 @@ public void map(InfoMap infoMap) {
.put(new Info("TFE_MonitoringIntGaugeCell::cell").javaText("public native @MemberGetter @ByRef IntGaugeCell cell();"))
.put(new Info("TFE_MonitoringStringGaugeCell::cell").javaText("public native @MemberGetter @ByRef StringGaugeCell cell();"))
.put(new Info("TFE_MonitoringBoolGaugeCell::cell").javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();"))
+ .put(new Info("TFE_Context").pointerTypes("TFE_Context").base("org.tensorflow.internal.c_api.AbstractTFE_Context"))
+ .put(new Info("TFE_ContextOptions").pointerTypes("TFE_ContextOptions").base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions"))
.put(new Info("TFE_Context::context").javaText("@MemberGetter public native @ByRef EagerContext context();"))
.put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();"))
- .put(new Info("TF_ShapeInferenceContextDimValueKnown", "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip())
- .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session"))
- .put(new Info("TF_WhileParams").purify());
+ .put(new Info("TF_ShapeInferenceContextDimValueKnown", "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip());
}
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.cc
deleted file mode 100644
index c8086d71ab3..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.cc
+++ /dev/null
@@ -1,335 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/eager_operation_builder_jni.h"
-
-#include
-#include
-#include
-
-#include "tensorflow/c/eager/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-// This value should be >= to the maximum number of outputs in any op
-#define MAX_OUTPUTS_PER_OP 8
-
-namespace {
-
-TFE_Op* requireOp(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "Operation has already been built");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TFE_Context* requireContext(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException, "Context has been deleted");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "close() has been called on the Tensor");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "Tensor handle has been deleted");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
- JNIEnv* env, jclass clazz, jlong context_handle, jstring name) {
- TFE_Context* context = requireContext(env, context_handle);
- if (context == nullptr) return 0;
- const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
- TF_Status* status = TF_NewStatus();
- TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
- env->ReleaseStringUTFChars(name, op_or_function_name);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
- "Cannot represent a C TFE_Op as a Java long");
- return reinterpret_cast(op);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete(
- JNIEnv* env, jclass clazz, jlong op_handle) {
- if (op_handle == 0) return;
- TFE_DeleteOp(reinterpret_cast(op_handle));
-}
-
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
- JNIEnv* env, jclass clazz, jlong op_handle) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return 0;
- int num_retvals = MAX_OUTPUTS_PER_OP;
- std::unique_ptr retvals(
- new TFE_TensorHandle*[num_retvals]);
- TF_Status* status = TF_NewStatus();
- TFE_Execute(op, retvals.get(), &num_retvals, status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return nullptr;
- }
- TF_DeleteStatus(status);
- jlongArray rethandles = env->NewLongArray(num_retvals);
- if (num_retvals > 0) {
- jlong* retval = env->GetLongArrayElements(rethandles, nullptr);
- for (int i = 0; i < num_retvals; ++i) {
- retval[i] = reinterpret_cast(retvals[i]);
- }
- env->ReleaseLongArrayElements(rethandles, retval, 0);
- }
- return rethandles;
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
- JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- const char* cname = env->GetStringUTFChars(device_name, nullptr);
- TF_Status* status = TF_NewStatus();
- TFE_OpSetDevice(op, cname, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(device_name, cname);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
- JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle);
- if (tensor_handle == nullptr) return;
- TF_Status* status = TF_NewStatus();
- TFE_OpAddInput(op, tensor_handle, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
- JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr);
- size_t num_inputs = static_cast(env->GetArrayLength(input_handles));
- std::unique_ptr tensor_handles(
- new TFE_TensorHandle*[num_inputs]);
- for (int i = 0; i < num_inputs; ++i) {
- tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]);
- if (tensor_handles[i] == nullptr) {
- env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
- return;
- }
- }
- env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
- TF_Status* status = TF_NewStatus();
- TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
- JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
- jbyteArray value) {
- static_assert(sizeof(jbyte) == 1,
- "Require Java byte to be represented as a single byte");
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- const char* cname = env->GetStringUTFChars(attr_name, nullptr);
- jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
- TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value));
- env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
- env->ReleaseStringUTFChars(attr_name, cname);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(
- JNIEnv* env, jclass object, jlong op_handle, jstring attr_name,
- jobjectArray values) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- const char* cname = env->GetStringUTFChars(attr_name, nullptr);
- int num_values = env->GetArrayLength(values);
- static_assert(sizeof(jbyte) == 1,
- "Require Java byte to be represented as a single byte");
- std::unique_ptr jarrays(new jbyteArray[num_values]);
- std::unique_ptr jvalues(new jbyte*[num_values]);
- std::unique_ptr cvalues(new void*[num_values]);
- std::unique_ptr lengths(new size_t[num_values]);
-
- for (int i = 0; i < num_values; ++i) {
- jbyteArray v =
- static_cast(env->GetObjectArrayElement(values, i));
- jarrays[i] = v;
- jvalues[i] = env->GetByteArrayElements(v, nullptr);
- cvalues[i] = jvalues[i];
- lengths[i] = static_cast(env->GetArrayLength(v));
- }
- TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values);
- for (int i = 0; i < num_values; ++i) {
- env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
- }
- env->ReleaseStringUTFChars(attr_name, cname);
-}
-
-#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
- JNIEXPORT void JNICALL \
- Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \
- JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
- jtype value) { \
- static_assert( \
- sizeof(ctype) >= sizeof(jtype), \
- "Information loss when converting between Java and C types"); \
- TFE_Op* op = requireOp(env, op_handle); \
- if (op == nullptr) return; \
- const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
- TFE_OpSetAttr##name(op, cname, static_cast(value)); \
- env->ReleaseStringUTFChars(attr_name, cname); \
- }
-
-#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
- JNIEXPORT void JNICALL \
- Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \
- JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
- jtype##Array value) { \
- TFE_Op* op = requireOp(env, op_handle); \
- if (op == nullptr) return; \
- const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
- /* Make a copy of the array to paper over any differences */ \
- /* in byte representations of the jtype and ctype */ \
- /* For example, jint vs TF_DataType. */ \
- /* If this copy turns out to be a problem in practice */ \
- /* can avoid it for many types. */ \
- const int n = env->GetArrayLength(value); \
- std::unique_ptr cvalue(new ctype[n]); \
- jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
- for (int i = 0; i < n; ++i) { \
- cvalue[i] = static_cast(elems[i]); \
- } \
- TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \
- env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
- env->ReleaseStringUTFChars(attr_name, cname); \
- }
-
-#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
- DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
- DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
-
-DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
-DEFINE_SET_ATTR(Float, Float, jfloat, float);
-DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
-DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
-#undef DEFINE_SET_ATTR
-#undef DEFINE_SET_ATTR_LIST
-#undef DEFINE_SET_ATTR_SCALAR
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
- JNIEnv* env, jclass clazz, jlong handle, jstring attr_name,
- jlong tensor_handle) {
- TFE_Op* op = requireOp(env, handle);
- if (op == nullptr) return;
- TF_Tensor* t = requireTensor(env, tensor_handle);
- if (t == nullptr) return;
- const char* cname = env->GetStringUTFChars(attr_name, nullptr);
- TF_Status* status = TF_NewStatus();
- TFE_OpSetAttrTensor(op, cname, t, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(attr_name, cname);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
- JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
- jlongArray shape, jint num_dims) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- std::unique_ptr cvalue;
- // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
- // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
- if (num_dims > 0) {
- cvalue.reset(new int64_t[num_dims]);
- jlong* elems = env->GetLongArrayElements(shape, nullptr);
- for (int i = 0; i < num_dims; ++i) {
- cvalue[i] = static_cast(elems[i]);
- }
- env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
- }
- const char* cname = env->GetStringUTFChars(attr_name, nullptr);
- TF_Status* status = TF_NewStatus();
- TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast(num_dims),
- status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(attr_name, cname);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(
- JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
- jlongArray shapes, jintArray num_dims) {
- TFE_Op* op = requireOp(env, op_handle);
- if (op == nullptr) return;
- std::unique_ptr cshapes;
- std::unique_ptr cdims;
- std::unique_ptr cnum_dims;
- const int num_dims_length = env->GetArrayLength(num_dims);
- if (num_dims_length > 0) {
- const int shapes_length = env->GetArrayLength(shapes);
- cshapes.reset(new int64_t[shapes_length]);
- cdims.reset(new const int64_t*[num_dims_length]);
- cnum_dims.reset(new int[num_dims_length]);
- jlong* shapes_elems =
- static_cast(env->GetPrimitiveArrayCritical(shapes, nullptr));
- std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
- env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
- int64_t* cshapes_ptr = cshapes.get();
- jint* num_dims_elems =
- static_cast(env->GetPrimitiveArrayCritical(num_dims, nullptr));
- for (int i = 0; i < num_dims_length; ++i) {
- cnum_dims[i] = static_cast(num_dims_elems[i]);
- cdims[i] = cshapes_ptr;
- if (cnum_dims[i] > 0) {
- cshapes_ptr += cnum_dims[i];
- }
- }
- env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
- }
- const char* cname = env->GetStringUTFChars(attr_name, nullptr);
- TF_Status* status = TF_NewStatus();
- TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(),
- num_dims_length, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(attr_name, cname);
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.h
deleted file mode 100644
index 6da891d7ae2..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_builder_jni.h
+++ /dev/null
@@ -1,191 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: allocate
- * Signature: (JLjava/lang/String;)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
- JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: execute
- * Signature: (J)[J
- */
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: addInput
- * Signature: (JJ)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
- JNIEnv *, jclass, jlong, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: addInputList
- * Signature: (J[J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
- JNIEnv *, jclass, jlong, jlongArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setDevice
- * Signature: (JLjava/lang/String;)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
- JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrString
- * Signature: (JLjava/lang/String;[B)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
- JNIEnv *, jclass, jlong, jstring, jbyteArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrStringList
- * Signature: (JLjava/lang/String;[L)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv *, jclass,
- jlong, jstring,
- jobjectArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrInt
- * Signature: (JLjava/lang/String;J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrInt(
- JNIEnv *, jclass, jlong, jstring, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrIntList
- * Signature: (JLjava/lang/String;[J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrIntList(
- JNIEnv *, jclass, jlong, jstring, jlongArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrFloat
- * Signature: (JLjava/lang/String;F)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrFloat(
- JNIEnv *, jclass, jlong, jstring, jfloat);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrFloatList
- * Signature: (JLjava/lang/String;[F)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrFloatList(JNIEnv *, jclass,
- jlong, jstring,
- jfloatArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrBool
- * Signature: (JLjava/lang/String;Z)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrBool(
- JNIEnv *, jclass, jlong, jstring, jboolean);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrBoolList
- * Signature: (JLjava/lang/String;[Z)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrBoolList(JNIEnv *, jclass,
- jlong, jstring,
- jbooleanArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrType
- * Signature: (JLjava/lang/String;I)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrType(
- JNIEnv *, jclass, jlong, jstring, jint);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrTypeList
- * Signature: (JLjava/lang/String;[I)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrTypeList(JNIEnv *, jclass,
- jlong, jstring,
- jintArray);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrTensor
- * Signature: (JLjava/lang/String;J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
- JNIEnv *, jclass, jlong, jstring, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrShape
- * Signature: (JLjava/lang/String;[JI)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
- JNIEnv *, jclass, jlong, jstring, jlongArray, jint);
-
-/*
- * Class: org_tensorflow_EagerOperationBuilder
- * Method: setAttrShapeList
- * Signature: (JLjava/lang/String;[J[I)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv *, jclass,
- jlong, jstring,
- jlongArray,
- jintArray);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.cc
deleted file mode 100644
index fb0d1c46751..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/eager_operation_jni.h"
-
-#include
-#include
-#include
-
-#include
-#include
-
-#include "tensorflow/c/eager/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-namespace {
-
-TFE_Op* requireOp(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "Eager session has been closed");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException, "EagerSession has been closed");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-} // namespace
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- if (handle == 0) return;
- TFE_DeleteOp(reinterpret_cast(handle));
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
- JNIEnv* env, jclass clazz, jlong handle) {
- if (handle == 0) return;
- TFE_DeleteTensorHandle(reinterpret_cast(handle));
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
- JNIEnv* env, jclass clazz, jlong handle) {
- TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
- if (tensor_handle == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
- "Cannot represent a C TF_Tensor as a Java long");
- return reinterpret_cast(tensor);
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
- JNIEnv* env, jclass clazz, jlong handle, jstring name) {
- TFE_Op* op = requireOp(env, handle);
- if (op == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- const char* cname = env->GetStringUTFChars(name, nullptr);
- int length = TFE_OpGetOutputLength(op, cname, status);
- env->ReleaseStringUTFChars(name, cname);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- return static_cast(length);
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
- JNIEnv* env, jclass clazz, jlong handle, jstring name) {
- TFE_Op* op = requireOp(env, handle);
- if (op == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- const char* cname = env->GetStringUTFChars(name, nullptr);
- int length = TFE_OpGetInputLength(op, cname, status);
- env->ReleaseStringUTFChars(name, cname);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- return static_cast(length);
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
- JNIEnv* env, jclass clazz, jlong handle) {
- TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
- if (tensor_handle == nullptr) return 0;
- TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle);
- return static_cast(data_type);
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
- JNIEnv* env, jclass clazz, jlong handle) {
- TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
- if (tensor_handle == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- int num_dims = TFE_TensorHandleNumDims(tensor_handle, status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- return static_cast(num_dims);
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint dim_index) {
- TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
- if (tensor_handle == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- return static_cast(dim);
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.h
deleted file mode 100644
index ef38ed038c9..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_operation_jni.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_EagerOperation
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *,
- jclass, jlong);
-
-/*
- * Class: org_tensorflow_EagerOperation
- * Method: deleteTensorHandle
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: resolveTensorHandle
- * Signature: (J)J
- */
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_EagerOperation_resolveTensorHandle(JNIEnv *, jclass, jlong);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: outputListLength
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
- JNIEnv *, jclass, jlong, jstring);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: inputListLength
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
- JNIEnv *, jclass, jlong, jstring);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: dataType
- * Signature: (J)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(JNIEnv *,
- jclass,
- jlong);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: numDims
- * Signature: (J)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(JNIEnv *,
- jclass,
- jlong);
-
-/**
- * Class: org_tensorflow_EagerOperation
- * Method: dim
- * Signature: (JI)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv *, jclass,
- jlong, jint);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.cc
deleted file mode 100644
index 852af6fb43f..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/eager_session_jni.h"
-
-#include
-#include
-
-#include "tensorflow/c/eager/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
- JNIEnv* env, jclass clazz, jboolean async, jint dpp, jbyteArray config) {
- TFE_ContextOptions* opts = TFE_NewContextOptions();
- jbyte* cconfig = nullptr;
- TF_Status* status = TF_NewStatus();
- if (config != nullptr) {
- cconfig = env->GetByteArrayElements(config, nullptr);
- TFE_ContextOptionsSetConfig(
- opts, cconfig, static_cast(env->GetArrayLength(config)),
- status);
- if (!throwExceptionIfNotOK(env, status)) {
- env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
- TFE_DeleteContextOptions(opts);
- TF_DeleteStatus(status);
- return 0;
- }
- }
- TFE_ContextOptionsSetAsync(opts, static_cast(async));
- TFE_ContextOptionsSetDevicePlacementPolicy(
- opts, static_cast(dpp));
- TFE_Context* context = TFE_NewContext(opts, status);
- TFE_DeleteContextOptions(opts);
- if (config != nullptr) {
- env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
- }
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- static_assert(sizeof(jlong) >= sizeof(TFE_Context*),
- "Cannot represent a C TFE_Op as a Java long");
- return reinterpret_cast(context);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- if (handle == 0) return;
- TFE_DeleteContext(reinterpret_cast(handle));
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.h
deleted file mode 100644
index 9f7bdaccd36..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/eager_session_jni.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_EagerSession
- * Method: allocate
- * Signature: (ZI[B)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
- JNIEnv *env, jclass clazz, jboolean async, jint dpp, jbyteArray config);
-
-/*
- * Class: org_tensorflow_EagerSession
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv *, jclass,
- jlong);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.cc
deleted file mode 100644
index 7b1d6508bd7..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include
-#include
-#include
-
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
-const char kIllegalStateException[] = "java/lang/IllegalStateException";
-const char kNullPointerException[] = "java/lang/NullPointerException";
-const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
-const char kUnsupportedOperationException[] =
- "java/lang/UnsupportedOperationException";
-
-void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
- va_list args;
- va_start(args, fmt);
- // Using vsnprintf() instead of vasprintf() because the latter doesn't seem to
- // be easily available on Windows.
- const size_t max_msg_len = 512;
- char* message = static_cast(malloc(max_msg_len));
- if (vsnprintf(message, max_msg_len, fmt, args) >= 0) {
- env->ThrowNew(env->FindClass(clazz), message);
- } else {
- env->ThrowNew(env->FindClass(clazz), "");
- }
- free(message);
- va_end(args);
-}
-
-namespace {
-// Map TF_Codes to unchecked exceptions.
-const char* exceptionClassName(TF_Code code) {
- switch (code) {
- case TF_OK:
- return nullptr;
- case TF_INVALID_ARGUMENT:
- return kIllegalArgumentException;
- case TF_UNAUTHENTICATED:
- case TF_PERMISSION_DENIED:
- return "java/lang/SecurityException";
- case TF_RESOURCE_EXHAUSTED:
- case TF_FAILED_PRECONDITION:
- return kIllegalStateException;
- case TF_OUT_OF_RANGE:
- return kIndexOutOfBoundsException;
- case TF_UNIMPLEMENTED:
- return kUnsupportedOperationException;
- default:
- return "org/tensorflow/TensorFlowException";
- }
-}
-} // namespace
-
-bool throwExceptionIfNotOK(JNIEnv* env, const TF_Status* status) {
- const char* clazz = exceptionClassName(TF_GetCode(status));
- if (clazz == nullptr) return true;
- env->ThrowNew(env->FindClass(clazz), TF_Message(status));
- return false;
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.h
deleted file mode 100644
index 465281f804e..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/exception_jni.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-struct TF_Status;
-
-extern const char kIllegalArgumentException[];
-extern const char kIllegalStateException[];
-extern const char kNullPointerException[];
-extern const char kIndexOutOfBoundsException[];
-extern const char kUnsupportedOperationException[];
-
-void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);
-
-// If status is not TF_OK, then throw an appropriate exception.
-// Returns true iff TF_GetCode(status) == TF_OK.
-bool throwExceptionIfNotOK(JNIEnv* env, const TF_Status* status);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.cc
deleted file mode 100644
index e50d1b5dfa6..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.cc
+++ /dev/null
@@ -1,335 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/graph_jni.h"
-
-#include
-#include
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-#include "src/main/native/utils_jni.h"
-
-namespace {
-template
-T* requireHandleImpl(JNIEnv* env, jlong handle) {
- static_assert(sizeof(jlong) >= sizeof(T*),
- "Cannot package C object pointers as a Java long");
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "close() has been called on the Graph");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
- return requireHandleImpl(env, handle);
-}
-
-TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
- return requireHandleImpl(env, handle);
-}
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv*, jclass) {
- return reinterpret_cast(TF_NewGraph());
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
- jlong handle) {
- if (handle == 0) return;
- TF_DeleteGraph(reinterpret_cast(handle));
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jstring name) {
- TF_Graph* g = requireHandle(env, handle);
- if (g == nullptr) return 0;
- const char* cname = env->GetStringUTFChars(name, nullptr);
- TF_Operation* op = TF_GraphOperationByName(g, cname);
- env->ReleaseStringUTFChars(name, cname);
- return reinterpret_cast(op);
-}
-
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
- JNIEnv* env, jclass clazz, jlong handle, jint position) {
- TF_Graph* g = requireHandle(env, handle);
- if (g == nullptr) return nullptr;
-
- size_t pos = static_cast(position);
- TF_Operation* operation = TF_GraphNextOperation(g, &pos);
- if (operation == nullptr) return nullptr;
-
- jlong handle_and_position[2];
- handle_and_position[0] = reinterpret_cast(operation);
- handle_and_position[1] = static_cast(pos);
-
- jlongArray rhett = env->NewLongArray(2);
- env->SetLongArrayRegion(rhett, 0, 2, handle_and_position);
- return rhett;
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
- JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
- jstring prefix) {
- TF_Graph* g = requireHandle(env, handle);
- if (g == nullptr) return;
-
- TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
-
- jboolean is_copy;
- const char* cprefix = env->GetStringUTFChars(prefix, &is_copy);
- TF_ImportGraphDefOptionsSetPrefix(opts, cprefix);
- env->ReleaseStringUTFChars(prefix, cprefix);
-
- static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
- jbyte* bytes = env->GetByteArrayElements(graph_def, &is_copy);
- TF_Buffer* buf =
- TF_NewBufferFromString(bytes, env->GetArrayLength(graph_def));
- TF_Status* status = TF_NewStatus();
-
- TF_GraphImportGraphDef(g, buf, opts, status);
- throwExceptionIfNotOK(env, status);
- // Continue cleaning up resources even if an exception was thrown.
-
- TF_DeleteStatus(status);
- TF_DeleteBuffer(buf);
- env->ReleaseByteArrayElements(graph_def, bytes, JNI_ABORT);
-
- TF_DeleteImportGraphDefOptions(opts);
-}
-
-JNIEXPORT jbyteArray JNICALL
-Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
- jbyteArray ret = nullptr;
- TF_Graph* g = requireHandle(env, handle);
- if (g == nullptr) return ret;
-
- TF_Buffer* buf = TF_NewBuffer();
- TF_Status* status = TF_NewStatus();
- TF_GraphToGraphDef(g, buf, status);
- if (throwExceptionIfNotOK(env, status)) {
- // sizeof(jsize) is less than sizeof(size_t) on some platforms.
- if (buf->length > std::numeric_limits::max()) {
- throwException(env, kIndexOutOfBoundsException,
- "GraphDef is too large to serialize into a byte[] array");
- } else {
- static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
- jint ret_len = static_cast(buf->length);
- ret = env->NewByteArray(ret_len);
- env->SetByteArrayRegion(ret, 0, ret_len,
- static_cast(buf->data));
- }
- }
- TF_DeleteStatus(status);
- TF_DeleteBuffer(buf);
- return ret;
-}
-
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
- JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
- jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
- jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
- TF_Graph* g = requireHandle(env, handle);
- if (g == nullptr) return nullptr;
-
- const jint ny = env->GetArrayLength(y_handles);
- const jint nx = env->GetArrayLength(x_handles);
-
- std::unique_ptr y(new TF_Output[ny]);
- std::unique_ptr x(new TF_Output[nx]);
- std::unique_ptr dx(nullptr);
- std::unique_ptr dy(new TF_Output[nx]);
-
- resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
- resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
- if (dx_handles != nullptr) {
- if (env->GetArrayLength(dx_handles) != ny) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d dx handles", ny,
- env->GetArrayLength(dx_handles));
- }
- dx.reset(new TF_Output[ny]);
- resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
- }
- if (env->ExceptionCheck()) return nullptr;
-
- const char* cprefix = nullptr;
- if (prefix != nullptr) {
- cprefix = env->GetStringUTFChars(prefix, nullptr);
- }
- TF_Status* status = TF_NewStatus();
- TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
- status, dy.get());
- if (prefix != nullptr) {
- env->ReleaseStringUTFChars(prefix, cprefix);
- }
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return nullptr;
- }
- TF_DeleteStatus(status);
-
- // returned array contains both op handles and output indices, in pair
- jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
- jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
- for (int i = 0, j = nx; i < nx; ++i, ++j) {
- TF_Output dy_output = dy.get()[i];
- dy_elems[i] = reinterpret_cast(dy_output.oper);
- dy_elems[j] = static_cast(dy_output.index);
- }
- env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
-
- return dy_handles_and_indices;
-}
-
-// helper function for while loop -- constructs conditional or body subgraph
-jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
- TF_Graph* const subgraph,
- const TF_Output* const inputs,
- const TF_Output* const outputs, const int ninputs,
- const int noutputs) {
- jmethodID build_subgraph_method_id = env->GetStaticMethodID(
- clazz, "buildSubgraph",
- "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
- if (build_subgraph_method_id == 0) return nullptr;
-
- jlong subgraph_handle = reinterpret_cast(subgraph);
-
- jlongArray input_handles = env->NewLongArray(ninputs);
- jintArray input_indices = env->NewIntArray(ninputs);
- jlongArray output_handles = env->NewLongArray(noutputs);
- jintArray output_indices = env->NewIntArray(noutputs);
-
- jlong* input_handles_elems =
- env->GetLongArrayElements(input_handles, nullptr);
- jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
- jlong* output_handles_elems =
- env->GetLongArrayElements(output_handles, nullptr);
- jint* output_indices_elems =
- env->GetIntArrayElements(output_indices, nullptr);
-
- for (int i = 0; i < ninputs; ++i) {
- input_handles_elems[i] = reinterpret_cast((inputs[i]).oper);
- input_indices_elems[i] = static_cast((inputs[i]).index);
- }
-
- for (int i = 0; i < noutputs; ++i) {
- output_handles_elems[i] = reinterpret_cast((outputs[i]).oper);
- output_indices_elems[i] = static_cast((outputs[i]).index);
- }
-
- env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
- env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
- env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
- env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
-
- // call Java code to construct the subgraph
- jlongArray output_handles_and_indices =
- (jlongArray)env->CallStaticObjectMethod(
- clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
- input_handles, input_indices, output_handles, output_indices);
-
- if (env->ExceptionOccurred()) {
- env->ExceptionDescribe();
- return nullptr;
- }
-
- // returned array contains both op handles and output indices, in pair
- return output_handles_and_indices;
-}
-
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
- JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
- jintArray input_indices, jstring name, jobject cond_graph_builder,
- jobject body_graph_builder) {
- TF_Graph* g = requireHandle(env, handle);
- TF_Status* status = TF_NewStatus();
- if (g == nullptr) return nullptr;
-
- int ninputs = env->GetArrayLength(input_handles);
-
- std::unique_ptr inputs(new TF_Output[ninputs]);
- resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
- ninputs);
- if (env->ExceptionCheck()) return nullptr;
-
- // initialize while params
- TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
- throwExceptionIfNotOK(env, status);
-
- // build conditional subgraph
- jlongArray cond_output_handles_and_indices =
- buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
- params.cond_inputs, ¶ms.cond_output, params.ninputs, 1);
-
- // build body subgraph
- jlongArray body_output_handles_and_indices = buildSubgraph(
- env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
- params.body_outputs, params.ninputs, params.ninputs);
-
- if (cond_output_handles_and_indices == nullptr ||
- body_output_handles_and_indices == nullptr)
- return nullptr;
-
- // set cond_output param to output of the conditional subgraph
- jlong* cond_output_elems =
- env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
- TF_Operation* cond_output_op =
- requireOperationHandle(env, cond_output_elems[0]);
- params.cond_output = {cond_output_op,
- static_cast(cond_output_elems[1])};
- env->ReleaseLongArrayElements(cond_output_handles_and_indices,
- cond_output_elems, 0);
-
- // set body_outputs param to outputs of the body subgraph
- jlong* body_output_elems =
- env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
- for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
- TF_Operation* body_output_op =
- requireOperationHandle(env, body_output_elems[i]);
- params.body_outputs[i] = {body_output_op,
- static_cast(body_output_elems[j])};
- }
- env->ReleaseLongArrayElements(body_output_handles_and_indices,
- body_output_elems, 0);
-
- // set loop name param
- params.name = env->GetStringUTFChars(name, 0);
-
- // build the while loop, storing loop outputs in `outputs`
- std::unique_ptr outputs(new TF_Output[ninputs]);
- TF_FinishWhile(¶ms, status, outputs.get());
-
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
-
- env->ReleaseStringUTFChars(name, params.name);
-
- // returned array contains both op handles and output indices, in pair
- jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
- jlong* output_elems =
- env->GetLongArrayElements(output_handles_and_indices, nullptr);
- for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
- TF_Output output = outputs.get()[i];
- output_elems[i] = reinterpret_cast(output.oper);
- output_elems[j] = static_cast(output.index);
- }
- env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
-
- return output_handles_and_indices;
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.h
deleted file mode 100644
index 4281297dca2..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_jni.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_Graph
- * Method: allocate
- * Signature: ()J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv *, jclass);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv *, jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: operation
- * Signature: (JLjava/lang/String;)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
- jlong, jstring);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: operations
- * Signature: (JI)[J
- */
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
- jclass,
- jlong,
- jint);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: importGraphDef
- * Signature: (J[BLjava/lang/String;)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(JNIEnv *,
- jclass, jlong,
- jbyteArray,
- jstring);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: toGraphDef
- * Signature: (J)[B
- */
-JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
- jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: name
- * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
- */
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
- JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
- jintArray, jlongArray, jintArray);
-
-/*
- * Class: org_tensorflow_Graph
- * Method: whileLoop
- * Signature:
- * (J[J[IILjava/lang/String;Lorg/tensorflow/Graph/WhileSubgraphBuilder;Lorg/tensorflow/Graph/WhileSubgraphBuilder;)[J
- */
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
- JNIEnv *, jclass, jlong, jlongArray, jintArray, jstring, jobject, jobject);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.cc
deleted file mode 100644
index dda2b4209ad..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.cc
+++ /dev/null
@@ -1,335 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/graph_operation_builder_jni.h"
-#include
-#include
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-namespace {
-TF_OperationDescription* requireHandle(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "Operation has already been built");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-bool resolveOutput(JNIEnv* env, jlong op_handle, jint index, TF_Output* out) {
- if (op_handle == 0) {
- throwException(env, kIllegalStateException,
- "close() was called on the Graph");
- return false;
- }
- out->oper = reinterpret_cast(op_handle);
- out->index = static_cast(index);
- return true;
-}
-
-TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "close() has been called on the Tensor");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate(
- JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) {
- if (graph_handle == 0) {
- throwException(env, kIllegalStateException,
- "close() has been called on the Graph");
- return 0;
- }
- TF_Graph* graph = reinterpret_cast(graph_handle);
- const char* op_type = env->GetStringUTFChars(type, nullptr);
- const char* op_name = env->GetStringUTFChars(name, nullptr);
- TF_OperationDescription* d = TF_NewOperation(graph, op_type, op_name);
- env->ReleaseStringUTFChars(name, op_name);
- env->ReleaseStringUTFChars(type, op_type);
- static_assert(sizeof(jlong) >= sizeof(TF_OperationDescription*),
- "Cannot represent a C TF_OperationDescription as a Java long");
- return reinterpret_cast(d);
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_finish(
- JNIEnv* env, jclass clazz, jlong handle) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return 0;
- TF_Status* status = TF_NewStatus();
- TF_Operation* op = TF_FinishOperation(d, status);
- if (throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return reinterpret_cast(op);
- }
- TF_DeleteStatus(status);
- return 0;
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput(
- JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) {
- TF_Output out;
- if (!resolveOutput(env, op_handle, index, &out)) return;
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- TF_AddInput(d, out);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList(
- JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles,
- jintArray indices) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- const size_t n = static_cast(env->GetArrayLength(op_handles));
- if (env->GetArrayLength(indices) != n) {
- throwException(env, kIllegalArgumentException,
- "mismatch in number of Operations (%d) and output indices "
- "(%d) provided",
- n, env->GetArrayLength(indices));
- return;
- }
- std::unique_ptr o(new TF_Output[n]);
- jlong* oph = env->GetLongArrayElements(op_handles, nullptr);
- jint* idx = env->GetIntArrayElements(indices, nullptr);
- bool ok = true;
- for (int i = 0; i < n && ok; ++i) {
- ok = resolveOutput(env, oph[i], idx[i], &o[i]);
- }
- env->ReleaseIntArrayElements(indices, idx, JNI_ABORT);
- env->ReleaseLongArrayElements(op_handles, oph, JNI_ABORT);
- if (!ok) return;
- TF_AddInputList(d, o.get(), n);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jlong op_handle) {
- if (op_handle == 0) {
- throwException(env, kIllegalStateException,
- "control input is not valid, "
- "perhaps the Graph containing it has been closed()?");
- return;
- }
- TF_Operation* control = reinterpret_cast(op_handle);
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- TF_AddControlInput(d, control);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice(
- JNIEnv* env, jclass clazz, jlong handle, jstring device) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- const char* cdevice = env->GetStringUTFChars(device, nullptr);
- TF_SetDevice(d, cdevice);
- env->ReleaseStringUTFChars(device, cdevice);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString(
- JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) {
- static_assert(sizeof(jbyte) == 1,
- "Require Java byte to be represented as a single byte");
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- const char* cname = env->GetStringUTFChars(name, nullptr);
- jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
- TF_SetAttrString(d, cname, cvalue, env->GetArrayLength(value));
- env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
- env->ReleaseStringUTFChars(name, cname);
-}
-
-#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
- JNIEXPORT void JNICALL \
- Java_org_tensorflow_GraphOperationBuilder_setAttr##name( \
- JNIEnv* env, jclass clazz, jlong handle, jstring name, \
- jtype value) { \
- static_assert( \
- sizeof(ctype) >= sizeof(jtype), \
- "Information loss when converting between Java and C types"); \
- TF_OperationDescription* d = requireHandle(env, handle); \
- if (d == nullptr) return; \
- const char* cname = env->GetStringUTFChars(name, nullptr); \
- TF_SetAttr##name(d, cname, static_cast(value)); \
- env->ReleaseStringUTFChars(name, cname); \
- }
-
-#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
- JNIEXPORT void JNICALL \
- Java_org_tensorflow_GraphOperationBuilder_setAttr##name##List( \
- JNIEnv* env, jclass clazz, jlong handle, jstring name, \
- jtype##Array value) { \
- TF_OperationDescription* d = requireHandle(env, handle); \
- if (d == nullptr) return; \
- const char* cname = env->GetStringUTFChars(name, nullptr); \
- /* Make a copy of the array to paper over any differences */ \
- /* in byte representations of the jtype and ctype */ \
- /* For example, jint vs TF_DataType. */ \
- /* If this copy turns out to be a problem in practice */ \
- /* can avoid it for many types. */ \
- const int n = env->GetArrayLength(value); \
- std::unique_ptr cvalue(new ctype[n]); \
- jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
- for (int i = 0; i < n; ++i) { \
- cvalue[i] = static_cast(elems[i]); \
- } \
- TF_SetAttr##name##List(d, cname, cvalue.get(), n); \
- env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
- env->ReleaseStringUTFChars(name, cname); \
- }
-
-#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
- DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
- DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
-
-DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
-DEFINE_SET_ATTR(Float, Float, jfloat, float);
-DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
-DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
-#undef DEFINE_SET_ATTR
-#undef DEFINE_SET_ATTR_LIST
-#undef DEFINE_SET_ATTR_SCALAR
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(
- JNIEnv* env, jclass clazz, jlong handle, jstring name,
- jlong tensor_handle) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- TF_Tensor* t = requireTensor(env, tensor_handle);
- if (t == nullptr) return;
- const char* cname = env->GetStringUTFChars(name, nullptr);
- TF_Status* status = TF_NewStatus();
- TF_SetAttrTensor(d, cname, t, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(name, cname);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(
- JNIEnv* env, jclass clazz, jlong handle, jstring name,
- jlongArray tensor_handles) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- const int n = env->GetArrayLength(tensor_handles);
- std::unique_ptr tensors(new TF_Tensor*[n]);
- jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr);
- bool ok = true;
- for (int i = 0; i < n && ok; ++i) {
- tensors[i] = requireTensor(env, jhandles[i]);
- ok = !env->ExceptionCheck();
- }
- env->ReleaseLongArrayElements(tensor_handles, jhandles, JNI_ABORT);
- if (!ok) return;
-
- const char* cname = env->GetStringUTFChars(name, nullptr);
- TF_Status* status = TF_NewStatus();
- TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- env->ReleaseStringUTFChars(name, cname);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape(
- JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape,
- jint num_dims) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- std::unique_ptr cvalue;
- // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
- // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
- if (num_dims > 0) {
- cvalue.reset(new int64_t[num_dims]);
- jlong* elems = env->GetLongArrayElements(shape, nullptr);
- for (int i = 0; i < num_dims; ++i) {
- cvalue[i] = static_cast(elems[i]);
- }
- env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
- }
- const char* cname = env->GetStringUTFChars(name, nullptr);
- TF_SetAttrShape(d, cname, cvalue.get(), static_cast(num_dims));
- env->ReleaseStringUTFChars(name, cname);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(
- JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes,
- jintArray num_dims) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- std::unique_ptr cshapes;
- std::unique_ptr cdims;
- std::unique_ptr cnum_dims;
- const int num_dims_length = env->GetArrayLength(num_dims);
- if (num_dims_length > 0) {
- const int shapes_length = env->GetArrayLength(shapes);
- cshapes.reset(new int64_t[shapes_length]);
- cdims.reset(new int64_t*[num_dims_length]);
- cnum_dims.reset(new int[num_dims_length]);
- jlong* shapes_elems =
- static_cast(env->GetPrimitiveArrayCritical(shapes, nullptr));
- std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
- env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
- int64_t* cshapes_ptr = cshapes.get();
- jint* num_dims_elems =
- static_cast(env->GetPrimitiveArrayCritical(num_dims, nullptr));
- for (int i = 0; i < num_dims_length; ++i) {
- cnum_dims[i] = static_cast(num_dims_elems[i]);
- cdims[i] = cshapes_ptr;
- if (cnum_dims[i] > 0) {
- cshapes_ptr += cnum_dims[i];
- }
- }
- env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
- }
- const char* cname = env->GetStringUTFChars(name, nullptr);
- TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length);
- env->ReleaseStringUTFChars(name, cname);
-}
-
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(
- JNIEnv* env, jclass object, jlong handle, jstring name,
- jobjectArray values) {
- TF_OperationDescription* d = requireHandle(env, handle);
- if (d == nullptr) return;
- const char* cname = env->GetStringUTFChars(name, nullptr);
- int num_values = env->GetArrayLength(values);
- static_assert(sizeof(jbyte) == 1,
- "Require Java byte to be represented as a single byte");
- std::unique_ptr jarrays(new jbyteArray[num_values]);
- std::unique_ptr jvalues(new jbyte*[num_values]);
- std::unique_ptr cvalues(new void*[num_values]);
- std::unique_ptr lengths(new size_t[num_values]);
-
- for (int i = 0; i < num_values; ++i) {
- jbyteArray v =
- static_cast(env->GetObjectArrayElement(values, i));
- jarrays[i] = v;
- jvalues[i] = env->GetByteArrayElements(v, nullptr);
- cvalues[i] = jvalues[i];
- lengths[i] = static_cast(env->GetArrayLength(v));
- }
- TF_SetAttrStringList(d, cname, cvalues.get(), lengths.get(), num_values);
- for (int i = 0; i < num_values; ++i) {
- env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
- }
- env->ReleaseStringUTFChars(name, cname);
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.h
deleted file mode 100644
index fe76fcf28e7..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_builder_jni.h
+++ /dev/null
@@ -1,202 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: allocate
- * Signature: (JLjava/lang/String;Ljava/lang/String;)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate(
- JNIEnv *, jclass, jlong, jstring, jstring);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: finish
- * Signature: (J)J
- */
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_GraphOperationBuilder_finish(JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: addInput
- * Signature: (JJI)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput(
- JNIEnv *, jclass, jlong, jlong, jint);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: addInputList
- * Signature: (J[J[I)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList(
- JNIEnv *, jclass, jlong, jlongArray, jintArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: addControlInput
- * Signature: (JJ)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv *, jclass,
- jlong, jlong);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setDevice
- * Signature: (JLjava/lang/String;)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice(
- JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrString
- * Signature: (JLjava/lang/String;[B)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString(
- JNIEnv *, jclass, jlong, jstring, jbyteArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrInt
- * Signature: (JLjava/lang/String;J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrInt(
- JNIEnv *, jclass, jlong, jstring, jlong);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrIntList
- * Signature: (JLjava/lang/String;[J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrIntList(
- JNIEnv *, jclass, jlong, jstring, jlongArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrFloat
- * Signature: (JLjava/lang/String;F)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrFloat(
- JNIEnv *, jclass, jlong, jstring, jfloat);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrFloatList
- * Signature: (JLjava/lang/String;[F)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrFloatList(JNIEnv *, jclass,
- jlong, jstring,
- jfloatArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrBool
- * Signature: (JLjava/lang/String;Z)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrBool(
- JNIEnv *, jclass, jlong, jstring, jboolean);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrBoolList
- * Signature: (JLjava/lang/String;[Z)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrBoolList(JNIEnv *, jclass,
- jlong, jstring,
- jbooleanArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrType
- * Signature: (JLjava/lang/String;I)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrType(
- JNIEnv *, jclass, jlong, jstring, jint);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrTypeList
- * Signature: (JLjava/lang/String;[I)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrTypeList(JNIEnv *, jclass,
- jlong, jstring,
- jintArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrTensor
- * Signature: (JLjava/lang/String;J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(
- JNIEnv *, jclass, jlong, jstring, jlong);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrTensorList
- * Signature: (JLjava/lang/String;[J)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(JNIEnv *, jclass,
- jlong, jstring,
- jlongArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrShape
- * Signature: (JLjava/lang/String;[JI)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape(
- JNIEnv *, jclass, jlong, jstring, jlongArray, jint);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrShapeList
- * Signature: (JLjava/lang/String;[J[I)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(JNIEnv *, jclass,
- jlong, jstring,
- jlongArray,
- jintArray);
-
-/*
- * Class: org_tensorflow_GraphOperationBuilder
- * Method: setAttrStringList
- * Signature: (JLjava/lang/String;[L)V
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(JNIEnv *, jclass,
- jlong, jstring,
- jobjectArray);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.cc
deleted file mode 100644
index f5860f7bf9a..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.cc
+++ /dev/null
@@ -1,166 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/graph_operation_jni.h"
-#include
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-namespace {
-template
-T* requireHandleImpl(JNIEnv* env, jlong handle) {
- static_assert(sizeof(jlong) >= sizeof(T*),
- "Cannot package C object pointers as a Java long");
- if (handle == 0) {
- throwException(
- env, kNullPointerException,
- "close() has been called on the Graph this Operation was a part of");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-TF_Operation* requireHandle(JNIEnv* env, jlong handle) {
- return requireHandleImpl(env, handle);
-}
-
-TF_Graph* requireGraphHandle(JNIEnv* env, jlong handle) {
- return requireHandleImpl(env, handle);
-}
-} // namespace
-
-JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_name(
- JNIEnv* env, jclass clazz, jlong handle) {
- TF_Operation* op = requireHandle(env, handle);
- if (op == nullptr) return nullptr;
- return env->NewStringUTF(TF_OperationName(op));
-}
-
-JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_type(
- JNIEnv* env, jclass clazz, jlong handle) {
- TF_Operation* op = requireHandle(env, handle);
- if (op == nullptr) return nullptr;
- return env->NewStringUTF(TF_OperationOpType(op));
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_numOutputs(
- JNIEnv* env, jclass clazz, jlong handle) {
- TF_Operation* op = requireHandle(env, handle);
- if (op == nullptr) return 0;
- return TF_OperationNumOutputs(op);
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_outputListLength(
- JNIEnv* env, jclass clazz, jlong handle, jstring name) {
- TF_Operation* op = requireHandle(env, handle);
- if (op == nullptr) return 0;
-
- TF_Status* status = TF_NewStatus();
-
- const char* cname = env->GetStringUTFChars(name, nullptr);
- int result = TF_OperationOutputListLength(op, cname, status);
- env->ReleaseStringUTFChars(name, cname);
-
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- return result;
-}
-
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_GraphOperation_shape(
- JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
- jint output_index) {
- TF_Graph* graph = requireGraphHandle(env, graph_handle);
- if (graph == nullptr) return nullptr;
- TF_Operation* op = requireHandle(env, op_handle);
- if (op == nullptr) return nullptr;
-
- int num_outputs = TF_OperationNumOutputs(op);
- if (output_index < 0 || output_index >= num_outputs) {
- throwException(
- env, kIndexOutOfBoundsException,
- "invalid output index (%d) for an operation that has %d outputs",
- output_index, num_outputs);
- return nullptr;
- }
-
- TF_Output output{op, output_index};
- TF_Status* status = TF_NewStatus();
- jsize num_dims = TF_GraphGetTensorNumDims(graph, output, status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return nullptr;
- }
- if (num_dims < 0) return nullptr;
- static_assert(sizeof(jlong) == sizeof(int64_t),
- "Java long is not compatible with the TensorFlow C API");
- // One might have trivially wanted to do:
- // TF_GraphGetTensorShape(graph, output, static_cast(dims), ...)
- // but on some platforms this fails with:
- // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
- // *') is not allowed
- // For now, do the expensive but safe thing of copying.
- std::unique_ptr cdims(new int64_t[num_dims]);
- TF_GraphGetTensorShape(graph, output, cdims.get(), static_cast(num_dims),
- status);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return nullptr;
- }
- TF_DeleteStatus(status);
-
- jlongArray ret = env->NewLongArray(num_dims);
- jlong* dims = env->GetLongArrayElements(ret, nullptr);
- for (int i = 0; i < num_dims; ++i) {
- dims[i] = static_cast(cdims[i]);
- }
- env->ReleaseLongArrayElements(ret, dims, 0);
- return ret;
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_dtype(
- JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
- jint output_index) {
- TF_Graph* graph = requireGraphHandle(env, graph_handle);
- if (graph == nullptr) return 0;
- TF_Operation* op = requireHandle(env, op_handle);
- if (op == nullptr) return 0;
-
- int num_outputs = TF_OperationNumOutputs(op);
- if (output_index < 0 || output_index >= num_outputs) {
- throwException(
- env, kIndexOutOfBoundsException,
- "invalid output index (%d) for an operation that has %d outputs",
- output_index, num_outputs);
- return 0;
- }
-
- return static_cast(TF_OperationOutputType(TF_Output{op, output_index}));
-}
-
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_inputListLength(
- JNIEnv* env, jclass clazz, jlong handle, jstring name) {
- TF_Operation* op = requireHandle(env, handle);
- if (op == nullptr) return 0;
-
- TF_Status* status = TF_NewStatus();
-
- const char* cname = env->GetStringUTFChars(name, nullptr);
- int result = TF_OperationInputListLength(op, cname, status);
- env->ReleaseStringUTFChars(name, cname);
-
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
- return result;
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.h
deleted file mode 100644
index bad4ada9cea..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/graph_operation_jni.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: name
- * Signature: (J)Ljava/lang/String;
- */
-JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_name(JNIEnv *,
- jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: type
- * Signature: (J)Ljava/lang/String;
- */
-JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_type(JNIEnv *,
- jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: numOutputs
- * Signature: (J)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_numOutputs(JNIEnv *,
- jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: outputListLength
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_outputListLength(
- JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: shape
- * Signature: (JJI)[J
- */
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_GraphOperation_shape(JNIEnv *, jclass, jlong, jlong, jint);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: dtype
- * Signature: (JJI)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_dtype(JNIEnv *,
- jclass, jlong,
- jlong, jint);
-
-/*
- * Class: org_tensorflow_GraphOperation
- * Method: inputListLength
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_inputListLength(
- JNIEnv *, jclass, jlong, jstring);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.cc
deleted file mode 100644
index b0158ce151c..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include
-#include
-
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-#include "src/main/native/saved_model_bundle_jni.h"
-
-JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
- JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
- jbyteArray config, jbyteArray run_options) {
- TF_Status* status = TF_NewStatus();
- jobject bundle = nullptr;
-
- // allocate parameters for TF_LoadSessionFromSavedModel
- TF_SessionOptions* opts = TF_NewSessionOptions();
- if (config != nullptr) {
- size_t sz = env->GetArrayLength(config);
- if (sz > 0) {
- jbyte* config_data = env->GetByteArrayElements(config, nullptr);
- TF_SetConfig(opts, static_cast(config_data), sz, status);
- env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteSessionOptions(opts);
- TF_DeleteStatus(status);
- return nullptr;
- }
- }
- }
- TF_Buffer* crun_options = nullptr;
- if (run_options != nullptr) {
- size_t sz = env->GetArrayLength(run_options);
- if (sz > 0) {
- jbyte* run_options_data = env->GetByteArrayElements(run_options, nullptr);
- crun_options =
- TF_NewBufferFromString(static_cast(run_options_data), sz);
- env->ReleaseByteArrayElements(run_options, run_options_data, JNI_ABORT);
- }
- }
- const char* cexport_dir = env->GetStringUTFChars(export_dir, nullptr);
- std::unique_ptr tags_ptrs;
- size_t tags_len = env->GetArrayLength(tags);
- tags_ptrs.reset(new const char*[tags_len]);
- for (size_t i = 0; i < tags_len; ++i) {
- jstring tag = static_cast(env->GetObjectArrayElement(tags, i));
- tags_ptrs[i] = env->GetStringUTFChars(tag, nullptr);
- env->DeleteLocalRef(tag);
- }
-
- // load the session
- TF_Graph* graph = TF_NewGraph();
- TF_Buffer* metagraph_def = TF_NewBuffer();
- TF_Session* session = TF_LoadSessionFromSavedModel(
- opts, crun_options, cexport_dir, tags_ptrs.get(), tags_len, graph,
- metagraph_def, status);
-
- // release the parameters
- TF_DeleteSessionOptions(opts);
- if (crun_options != nullptr) {
- TF_DeleteBuffer(crun_options);
- }
- env->ReleaseStringUTFChars(export_dir, cexport_dir);
- for (size_t i = 0; i < tags_len; ++i) {
- jstring tag = static_cast(env->GetObjectArrayElement(tags, i));
- env->ReleaseStringUTFChars(tag, tags_ptrs[i]);
- env->DeleteLocalRef(tag);
- }
-
- // handle the result
- if (throwExceptionIfNotOK(env, status)) {
- // sizeof(jsize) is less than sizeof(size_t) on some platforms.
- if (metagraph_def->length > std::numeric_limits::max()) {
- throwException(
- env, kIndexOutOfBoundsException,
- "MetaGraphDef is too large to serialize into a byte[] array");
- } else {
- static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
- jint jmetagraph_len = static_cast(metagraph_def->length);
- jbyteArray jmetagraph_def = env->NewByteArray(jmetagraph_len);
- env->SetByteArrayRegion(jmetagraph_def, 0, jmetagraph_len,
- static_cast(metagraph_def->data));
-
- jmethodID method = env->GetStaticMethodID(
- clazz, "fromHandle", "(JJ[B)Lorg/tensorflow/SavedModelBundle;");
- bundle = env->CallStaticObjectMethod(
- clazz, method, reinterpret_cast(graph),
- reinterpret_cast(session), jmetagraph_def);
- graph = nullptr;
- session = nullptr;
- env->DeleteLocalRef(jmetagraph_def);
- }
- }
-
- if (session != nullptr) {
- TF_CloseSession(session, status);
- // Result of close is ignored, delete anyway.
- TF_DeleteSession(session, status);
- }
- if (graph != nullptr) {
- TF_DeleteGraph(graph);
- }
- TF_DeleteBuffer(metagraph_def);
- TF_DeleteStatus(status);
-
- return bundle;
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.h
deleted file mode 100644
index e8f28dd670d..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/saved_model_bundle_jni.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_SavedModelBundle
- * Method: load
- * Signature:
- * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
- */
-JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
- JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.cc
deleted file mode 100644
index b3ca3bcf053..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/server_jni.h"
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-#include "src/main/native/utils_jni.h"
-
-namespace {
-
-TF_Server* requireHandle(JNIEnv* env, jlong handle) {
- static_assert(sizeof(jlong) >= sizeof(TF_Server*),
- "Cannot package C object pointers as a Java long");
- if (handle == 0) {
- throwException(env, kIllegalStateException,
- "close() has been called on the Server");
- return nullptr;
- }
-
- return reinterpret_cast(handle);
-}
-
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Server_allocate(
- JNIEnv* env, jclass clazz, jbyteArray server_def) {
- TF_Status* status = TF_NewStatus();
-
- jbyte* server_def_ptr = env->GetByteArrayElements(server_def, nullptr);
-
- TF_Server* server = TF_NewServer(
- server_def_ptr, static_cast(env->GetArrayLength(server_def)),
- status);
-
- env->ReleaseByteArrayElements(server_def, server_def_ptr, JNI_ABORT);
- bool ok = throwExceptionIfNotOK(env, status);
-
- TF_DeleteStatus(status);
-
- return ok ? reinterpret_cast(server) : 0;
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_start(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- TF_Server* server = requireHandle(env, handle);
- if (server == nullptr) return;
-
- TF_Status* status = TF_NewStatus();
-
- TF_ServerStart(server, status);
- throwExceptionIfNotOK(env, status);
-
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_stop(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- TF_Server* server = requireHandle(env, handle);
- if (server == nullptr) return;
-
- TF_Status* status = TF_NewStatus();
-
- TF_ServerStop(server, status);
- throwExceptionIfNotOK(env, status);
-
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_join(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- TF_Server* server = requireHandle(env, handle);
- if (server == nullptr) return;
-
- TF_Status* status = TF_NewStatus();
-
- TF_ServerJoin(server, status);
- throwExceptionIfNotOK(env, status);
-
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_delete(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- TF_Server* server = requireHandle(env, handle);
- if (server == nullptr) return;
-
- TF_DeleteServer(server);
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.h
deleted file mode 100644
index 4bfe90b7a85..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/server_jni.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SERVER_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SERVER_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_Server
- * Method: allocate
- * Signature: ([B)J
- */
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_Server_allocate(JNIEnv *, jclass, jbyteArray server_def);
-
-/*
- * Class: org_tensorflow_Server
- * Method: start
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_start(JNIEnv *, jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_Server
- * Method: stop
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_stop(JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_tensorflow_Session
- * Method: join
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_join(JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_tensorflow_Session
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Server_delete(JNIEnv *, jclass,
- jlong);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SERVER_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.cc
deleted file mode 100644
index 8df682330b5..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.cc
+++ /dev/null
@@ -1,203 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include
-#include
-
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/utils_jni.h"
-#include "src/main/native/exception_jni.h"
-#include "src/main/native/session_jni.h"
-
-namespace {
-TF_Session* requireHandle(JNIEnv* env, jlong handle) {
- static_assert(sizeof(jlong) >= sizeof(TF_Session*),
- "Cannot package C object pointers as a Java long");
- if (handle == 0) {
- throwException(env, kNullPointerException,
- "close() has been called on the Session");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-template
-void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
- T** dst, jint n) {
- if (env->ExceptionCheck()) return;
- jint len = env->GetArrayLength(src_array);
- if (len != n) {
- throwException(env, kIllegalArgumentException, "expected %d, got %d %s", n,
- len, type);
- return;
- }
- jlong* src_start = env->GetLongArrayElements(src_array, nullptr);
- jlong* src = src_start;
- for (int i = 0; i < n; ++i, ++src, ++dst) {
- if (*src == 0) {
- throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
- i, n);
- break;
- }
- *dst = reinterpret_cast(*src);
- }
- env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
-}
-
-void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
- if (buf == nullptr) return;
- TF_DeleteBuffer(buf);
-}
-
-typedef std::unique_ptr
- unique_tf_buffer;
-
-unique_tf_buffer MakeUniqueBuffer(TF_Buffer* buf) {
- return unique_tf_buffer(buf, TF_MaybeDeleteBuffer);
-}
-
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(
- JNIEnv* env, jclass clazz, jlong graph_handle) {
- return Java_org_tensorflow_Session_allocate2(env, clazz, graph_handle,
- nullptr, nullptr);
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
- JNIEnv* env, jclass clazz, jlong graph_handle, jstring target,
- jbyteArray config) {
- if (graph_handle == 0) {
- throwException(env, kNullPointerException, "Graph has been close()d");
- return 0;
- }
- TF_Graph* graph = reinterpret_cast(graph_handle);
- TF_Status* status = TF_NewStatus();
- TF_SessionOptions* opts = TF_NewSessionOptions();
- jbyte* cconfig = nullptr;
- if (config != nullptr) {
- cconfig = env->GetByteArrayElements(config, nullptr);
- TF_SetConfig(opts, cconfig,
- static_cast(env->GetArrayLength(config)), status);
- if (!throwExceptionIfNotOK(env, status)) {
- env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
- TF_DeleteSessionOptions(opts);
- TF_DeleteStatus(status);
- return 0;
- }
- }
- const char* ctarget = nullptr;
- if (target != nullptr) {
- ctarget = env->GetStringUTFChars(target, nullptr);
- }
- TF_Session* session = TF_NewSession(graph, opts, status);
- if (config != nullptr) {
- env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
- }
- if (target != nullptr) {
- env->ReleaseStringUTFChars(target, ctarget);
- }
- TF_DeleteSessionOptions(opts);
- bool ok = throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
-
- return ok ? reinterpret_cast(session) : 0;
-}
-
-JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv* env,
- jclass clazz,
- jlong handle) {
- TF_Session* session = requireHandle(env, handle);
- if (session == nullptr) return;
- TF_Status* status = TF_NewStatus();
- TF_CloseSession(session, status);
- // Result of close is ignored, delete anyway.
- TF_DeleteSession(session, status);
- throwExceptionIfNotOK(env, status);
- TF_DeleteStatus(status);
-}
-
-JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
- JNIEnv* env, jclass clazz, jlong handle, jbyteArray jrun_options,
- jlongArray input_tensor_handles, jlongArray input_op_handles,
- jintArray input_op_indices, jlongArray output_op_handles,
- jintArray output_op_indices, jlongArray target_op_handles,
- jboolean want_run_metadata, jlongArray output_tensor_handles) {
- TF_Session* session = requireHandle(env, handle);
- if (session == nullptr) return nullptr;
-
- const jint ninputs = env->GetArrayLength(input_tensor_handles);
- const jint noutputs = env->GetArrayLength(output_tensor_handles);
- const jint ntargets = env->GetArrayLength(target_op_handles);
-
- std::unique_ptr inputs(new TF_Output[ninputs]);
- std::unique_ptr input_values(new TF_Tensor*[ninputs]);
- std::unique_ptr outputs(new TF_Output[noutputs]);
- std::unique_ptr output_values(new TF_Tensor*[noutputs]);
- std::unique_ptr targets(new TF_Operation*[ntargets]);
- unique_tf_buffer run_metadata(
- MakeUniqueBuffer(want_run_metadata ? TF_NewBuffer() : nullptr));
-
- resolveHandles(env, "input Tensors", input_tensor_handles, input_values.get(),
- ninputs);
- resolveOutputs(env, "input", input_op_handles, input_op_indices, inputs.get(),
- ninputs);
- resolveOutputs(env, "output", output_op_handles, output_op_indices,
- outputs.get(), noutputs);
- resolveHandles(env, "target Operations", target_op_handles, targets.get(),
- ntargets);
- if (env->ExceptionCheck()) return nullptr;
-
- TF_Status* status = TF_NewStatus();
-
- unique_tf_buffer run_options(MakeUniqueBuffer(nullptr));
- jbyte* jrun_options_data = nullptr;
- if (jrun_options != nullptr) {
- size_t sz = env->GetArrayLength(jrun_options);
- if (sz > 0) {
- jrun_options_data = env->GetByteArrayElements(jrun_options, nullptr);
- run_options.reset(
- TF_NewBufferFromString(static_cast(jrun_options_data), sz));
- }
- }
-
- TF_SessionRun(session, run_options.get(), inputs.get(), input_values.get(),
- static_cast(ninputs), outputs.get(), output_values.get(),
- static_cast(noutputs), targets.get(),
- static_cast(ntargets), run_metadata.get(), status);
-
- if (jrun_options_data != nullptr) {
- env->ReleaseByteArrayElements(jrun_options, jrun_options_data, JNI_ABORT);
- }
-
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return nullptr;
- }
- jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
- for (int i = 0; i < noutputs; ++i) {
- t[i] = reinterpret_cast(output_values[i]);
- }
- env->ReleaseLongArrayElements(output_tensor_handles, t, 0);
-
- jbyteArray ret = nullptr;
- if (run_metadata != nullptr) {
- ret = env->NewByteArray(run_metadata->length);
- env->SetByteArrayRegion(ret, 0, run_metadata->length,
- reinterpret_cast(run_metadata->data));
- }
- TF_DeleteStatus(status);
- return ret;
-}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.h b/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.h
deleted file mode 100644
index 1cc196bdc8a..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/session_jni.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
-#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
-
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*
- * Class: org_tensorflow_Session
- * Method: allocate
- * Signature: (J)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(JNIEnv *, jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_Session
- * Method: allocate2
- * Signature: (JLjava/lang/String;[B)J
- */
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(JNIEnv *, jclass,
- jlong, jstring,
- jbyteArray);
-
-/*
- * Class: org_tensorflow_Session
- * Method: delete
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv *, jclass,
- jlong);
-
-/*
- * Class: org_tensorflow_Session
- * Method: run
- * Signature: (J[B[J[J[I[J[I[JZ[J)[B
- */
-JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
- JNIEnv *, jclass, jlong, jbyteArray, jlongArray, jlongArray, jintArray,
- jlongArray, jintArray, jlongArray, jboolean, jlongArray);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
diff --git a/tensorflow-core/tensorflow-core-api/src/main/native/tensor_jni.cc b/tensorflow-core/tensorflow-core-api/src/main/native/tensor_jni.cc
deleted file mode 100644
index fe32637eecc..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/native/tensor_jni.cc
+++ /dev/null
@@ -1,623 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "src/main/native/tensor_jni.h"
-
-#include
-#include
-#include
-#include
-#include
-
-#include "tensorflow/c/c_api.h"
-#include "src/main/native/exception_jni.h"
-
-namespace {
-
-TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
- if (handle == 0) {
- throwException(env, kNullPointerException,
- "close() was called on the Tensor");
- return nullptr;
- }
- return reinterpret_cast(handle);
-}
-
-size_t elemByteSize(TF_DataType dtype) {
- // The code in this file makes the assumption that the
- // TensorFlow TF_DataTypes and the Java primitive types
- // have the same byte sizes. Validate that:
- switch (dtype) {
- case TF_BOOL:
- case TF_UINT8:
- static_assert(sizeof(jboolean) == 1,
- "Java boolean not compatible with TF_BOOL");
- static_assert(sizeof(jbyte) == 1,
- "Java byte not compatible with TF_UINT8");
- return 1;
- case TF_FLOAT:
- case TF_INT32:
- static_assert(sizeof(jfloat) == 4,
- "Java float not compatible with TF_FLOAT");
- static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
- return 4;
- case TF_DOUBLE:
- case TF_INT64:
- static_assert(sizeof(jdouble) == 8,
- "Java double not compatible with TF_DOUBLE");
- static_assert(sizeof(jlong) == 8,
- "Java long not compatible with TF_INT64");
- return 8;
- default:
- return 0;
- }
-}
-
-// Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
-void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
- size_t dst_size) {
- size_t sz = elemByteSize(dtype);
- if (sz != dst_size) {
- throwException(
- env, kIllegalStateException,
- "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
- dst_size);
- return;
- }
- switch (dtype) {
-// env->FindClass and env->GetMethodID are expensive and JNI best practices
-// suggest that they should be cached. However, until the creation of scalar
-// valued tensors seems to become a noticeable fraction of program execution,
-// ignore that cost.
-#define CASE(dtype, jtype, method_name, method_signature, call_type) \
- case dtype: { \
- jclass clazz = env->FindClass("java/lang/Number"); \
- jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
- jtype v = env->Call##call_type##Method(src, method); \
- memcpy(dst, &v, sz); \
- return; \
- }
- CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
- CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
- CASE(TF_INT32, jint, "intValue", "()I", Int);
- CASE(TF_INT64, jlong, "longValue", "()J", Long);
- CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
-#undef CASE
- case TF_BOOL: {
- jclass clazz = env->FindClass("java/lang/Boolean");
- jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
- jboolean v = env->CallBooleanMethod(src, method);
- *(static_cast(dst)) = v ? 1 : 0;
- return;
- }
- default:
- throwException(env, kIllegalStateException, "invalid DataType(%d)",
- dtype);
- return;
- }
-}
-
-// Copy a 1-D array of Java primitive types to the tensor buffer dst.
-// Returns the number of bytes written to dst.
-size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
- size_t dst_size) {
- const int nelems = env->GetArrayLength(array);
- jboolean is_copy;
- switch (dtype) {
-#define CASE(dtype, jtype, get_type) \
- case dtype: { \
- jtype##Array a = static_cast(array); \
- jtype* values = env->Get##get_type##ArrayElements(a, &is_copy); \
- size_t to_copy = nelems * elemByteSize(dtype); \
- if (to_copy > dst_size) { \
- throwException( \
- env, kIllegalStateException, \
- "cannot write Java array of %d bytes to Tensor of %d bytes", \
- to_copy, dst_size); \
- to_copy = 0; \
- } else { \
- memcpy(dst, values, to_copy); \
- } \
- env->Release##get_type##ArrayElements(a, values, JNI_ABORT); \
- return to_copy; \
- }
- CASE(TF_FLOAT, jfloat, Float);
- CASE(TF_DOUBLE, jdouble, Double);
- CASE(TF_INT32, jint, Int);
- CASE(TF_INT64, jlong, Long);
- CASE(TF_BOOL, jboolean, Boolean);
- CASE(TF_UINT8, jbyte, Byte);
-#undef CASE
- default:
- throwException(env, kIllegalStateException, "invalid DataType(%d)",
- dtype);
- return 0;
- }
-}
-
-// Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
-// Java primitive types. Returns the number of bytes read from src.
-size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
- size_t src_size, jarray dst) {
- const int len = env->GetArrayLength(dst);
- const size_t sz = len * elemByteSize(dtype);
- if (sz > src_size) {
- throwException(
- env, kIllegalStateException,
- "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
- src_size);
- return 0;
- }
- switch (dtype) {
-#define CASE(dtype, jtype, primitive_type) \
- case dtype: { \
- jtype##Array arr = static_cast(dst); \
- env->Set##primitive_type##ArrayRegion(arr, 0, len, \
- static_cast(src)); \
- return sz; \
- }
- CASE(TF_FLOAT, jfloat, Float);
- CASE(TF_DOUBLE, jdouble, Double);
- CASE(TF_INT32, jint, Int);
- CASE(TF_INT64, jlong, Long);
- CASE(TF_BOOL, jboolean, Boolean);
- CASE(TF_UINT8, jbyte, Byte);
-#undef CASE
- default:
- throwException(env, kIllegalStateException, "invalid DataType(%d)",
- dtype);
- }
- return 0;
-}
-
-size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
- char* dst, size_t dst_size) {
- if (dims_left == 1) {
- return write1DArray(env, src, dtype, dst, dst_size);
- } else {
- jobjectArray ndarray = static_cast(src);
- int len = env->GetArrayLength(ndarray);
- size_t sz = 0;
- for (int i = 0; i < len; ++i) {
- jarray row = static_cast(env->GetObjectArrayElement(ndarray, i));
- sz +=
- writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
- env->DeleteLocalRef(row);
- if (env->ExceptionCheck()) return sz;
- }
- return sz;
- }
-}
-
-size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
- size_t src_size, int dims_left, jarray dst) {
- if (dims_left == 1) {
- return read1DArray(env, dtype, src, src_size, dst);
- } else {
- jobjectArray ndarray = static_cast(dst);
- int len = env->GetArrayLength(ndarray);
- size_t sz = 0;
- for (int i = 0; i < len; ++i) {
- jarray row = static_cast(env->GetObjectArrayElement(ndarray, i));
- sz +=
- readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
- env->DeleteLocalRef(row);
- if (env->ExceptionCheck()) return sz;
- }
- return sz;
- }
-}
-
-jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
- size_t src_len, TF_Status* status) {
- const char* dst = nullptr;
- size_t dst_len = 0;
- TF_StringDecode(src, src_len, &dst, &dst_len, status);
- if (TF_GetCode(status) != TF_OK) {
- return nullptr;
- }
- jbyteArray ret = env->NewByteArray(dst_len);
- jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
- memcpy(cpy, dst, dst_len);
- env->ReleaseByteArrayElements(ret, cpy, 0);
- return ret;
-}
-
-class StringTensorWriter {
- public:
- StringTensorWriter(TF_Tensor* t, int num_elements)
- : offset_(0),
- poffsets_(static_cast(TF_TensorData(t))),
- pdata_(poffsets_ + 8 * num_elements),
- plimit_(poffsets_ + TF_TensorByteSize(t)) {}
-
- void Add(const char* src, size_t len, TF_Status* status) {
- if (TF_GetCode(status) != TF_OK) return;
- if (plimit_ - poffsets_ < sizeof(offset_)) {
- TF_SetStatus(status, TF_OUT_OF_RANGE,
- "TF_STRING tensor encoding ran out of space for offsets, "
- "this is likely a bug, please file an issue at "
- "https://github.com/tensorflow/tensorflow/issues/new");
- return;
- }
- memcpy(poffsets_, &offset_, sizeof(offset_));
- size_t written =
- TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
- offset_ += written;
- poffsets_ += 8;
- pdata_ += written;
- }
-
- private:
- uint64_t offset_;
- char* poffsets_;
- char* pdata_;
- const char* plimit_;
-};
-
-class StringTensorReader {
- public:
- StringTensorReader(const TF_Tensor* t, int num_elements)
- : index_(0),
- offsets_(static_cast(TF_TensorData(t))),
- data_(offsets_ + 8 * num_elements),
- limit_(offsets_ + TF_TensorByteSize(t)) {}
-
- jbyteArray Next(JNIEnv* env, TF_Status* status) {
- if (TF_GetCode(status) != TF_OK) return nullptr;
- uint64_t offset = 0;
- const char* poffset = offsets_ + sizeof(offset) * index_;
- if (poffset >= limit_) {
- TF_SetStatus(
- status, TF_INTERNAL,
- "Invalid TF_STRING tensor, offsets table seems to be too small");
- return nullptr;
- }
- memcpy(&offset, poffset, sizeof(offset));
- const char* pdata = data_ + offset;
- if (pdata >= limit_) {
- TF_SetStatus(status, TF_INTERNAL,
- "Invalid TF_STRING tensor, invalid entry in offset table");
- return nullptr;
- }
- ++index_;
- return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
- }
-
- private:
- int index_;
- const char* offsets_;
- const char* data_;
- const char* limit_;
-};
-
-void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
- jobjectArray dst, TF_Status* status) {
- jsize len = env->GetArrayLength(dst);
- if (dims_left == 1) {
- for (jsize i = 0; i < len; ++i) {
- jbyteArray elem = reader->Next(env, status);
- if (TF_GetCode(status) != TF_OK) return;
- env->SetObjectArrayElement(dst, i, elem);
- }
- return;
- }
- for (jsize i = 0; i < len; ++i) {
- jobjectArray arr =
- static_cast(env->GetObjectArrayElement(dst, i));
- readNDStringArray(env, reader, dims_left - 1, arr, status);
- if (TF_GetCode(status) != TF_OK) return;
- }
-}
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
- jclass clazz,
- jint dtype,
- jlongArray shape,
- jlong sizeInBytes) {
- int num_dims = static_cast(env->GetArrayLength(shape));
- jlong* dims = nullptr;
- if (num_dims > 0) {
- jboolean is_copy;
- dims = env->GetLongArrayElements(shape, &is_copy);
- }
- static_assert(sizeof(jlong) == sizeof(int64_t),
- "Java long is not compatible with the TensorFlow C API");
- // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
- //
- // Thus, static_cast(dims) will trigger a compiler error:
- // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
- // *') is not allowed
- //
- // Since this array is typically very small, use the guaranteed safe scheme of
- // creating a copy.
- int64_t* dims_copy = new int64_t[num_dims];
- for (int i = 0; i < num_dims; ++i) {
- dims_copy[i] = static_cast(dims[i]);
- }
- TF_Tensor* t = TF_AllocateTensor(static_cast(dtype), dims_copy,
- num_dims, static_cast(sizeInBytes));
- delete[] dims_copy;
- if (dims != nullptr) {
- env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
- }
- if (t == nullptr) {
- throwException(env, kNullPointerException,
- "unable to allocate memory for the Tensor");
- return 0;
- }
- return reinterpret_cast(t);
-}
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
- JNIEnv* env, jclass clazz, jbyteArray value) {
- // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
- // TF_StringEncode-encoded bytes.
- size_t src_len = static_cast(env->GetArrayLength(value));
- size_t dst_len = TF_StringEncodedSize(src_len);
- TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
- char* dst = static_cast(TF_TensorData(t));
- memset(dst, 0, 8); // The offset table
-
- TF_Status* status = TF_NewStatus();
- jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
- // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
- // reinterpret_cast<> for this conversion should be safe.
- TF_StringEncode(reinterpret_cast(jsrc), src_len, dst + 8,
- dst_len, status);
- env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
- if (!throwExceptionIfNotOK(env, status)) {
- TF_DeleteStatus(status);
- return 0;
- }
- TF_DeleteStatus(status);
- return reinterpret_cast(t);
-}
-
-namespace {
-size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
- if (num_dims == 0) {
- // This is the last dimension, i.e., value should correspond to a jbyteArray
- // encoding the string.
- return TF_StringEncodedSize(
- static_cast(env->GetArrayLength(value)));
- }
- jsize len = env->GetArrayLength(value);
- size_t ret = 0;
- for (jsize i = 0; i < len; ++i) {
- jarray elem = static_cast(
- env->GetObjectArrayElement(static_cast(value), i));
- if (elem == nullptr) {
- throwException(env, kNullPointerException,
- "null entries in provided array");
- return ret;
- }
- ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
- if (env->ExceptionCheck()) return ret;
- }
- return ret;
-}
-
-void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
- StringTensorWriter* writer,
- TF_Status* status) {
- if (num_dims == 0) {
- jbyte* jsrc =
- env->GetByteArrayElements(static_cast(value), nullptr);
- writer->Add(reinterpret_cast(jsrc), env->GetArrayLength(value),
- status);
- env->ReleaseByteArrayElements(static_cast(value), jsrc,
- JNI_ABORT);
- return;
- }
- jsize len = env->GetArrayLength(value);
- for (jsize i = 0; i < len; ++i) {
- jarray elem = static_cast(
- env->GetObjectArrayElement(static_cast(value), i));
- fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
- if (TF_GetCode(status) != TF_OK) return;
- }
-}
-} // namespace
-
-JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
- JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
- // TF_STRING tensors are encoded with a table of 8-byte offsets following by
- // TF_StringEncode-encoded bytes.
- const int num_dims = static_cast