Hi @aday00,
I'm not sure why you are getting this error; it looks like your generated protos are out of sync somehow. Maybe running `mvn clean` first will fix it? (Be aware that this will retrigger a full Bazel build, though.)
But the `save_model` branch in my personal repo was just a very experimental demonstration of what could be done, and I was not really expecting to support it for real usage. We just synced up on this topic during the SIG session, and we now have a clear plan for how we want to design this before merging it into `tensorflow/java`. This should happen in the next month or so.
Thank you kindly, @karllessard!
It appears the SIG document is https://docs.google.com/document/d/1RqSIKKFLE8kMSjUuZWsVS4JNCRHIrCYiuWcFZRNk2oU/edit# and mentions this:
Saved models and function graphs
- Shajan's changes have been pushed to this branch
- Interesting thread about this topic
- We should build an API that is ready to:
  - Serve session-centric models (estimators and TF1.x)
  - Serve function-centric models (TF2.x)
  - Create function graphs, referenced as an attribute of an op (e.g. `tf.dataset.map`)
  - Create function graphs with signature (like `tf.function` does)
- Should we save session-centric models from Java? The C API only supports this “officially” for now
- Confirmation: let's model the API with functions while still loading/saving session-centric graphs
It appears that @Shajan's changes are discussed in #89 and in commit d845856.
I don't see a unit test, so I'm not sure how to use this to save a model. Any tips would be greatly appreciated!
In the meantime, I'm happy to continue with the `save_model` branch and `SavedModelBundle.exporter(...)`. Thanks for this pre-alpha feature, and for all the time you've invested in stewarding this project! Hopefully I can contribute a bit by testing.
Yes, the portion for saving the model is not in this branch yet, but work will start soon and it might differ a bit from what you currently see in my `save_model` branch, because we will opt for a more "function-centric" API.
I'll keep you posted! Please feel free to share any feedback on the draft version in my repo, as it may also apply to the next version.
Thanks, @karllessard! I ran `mvn clean` as suggested:
root@478b80e86d3b:tensorflow-java-karl# mvn clean
[INFO] Scanning for projects...
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Build Order:
[INFO]
[INFO] TensorFlow Java Parent [pom]
[INFO] TensorFlow Tools Library [jar]
[INFO] TensorFlow Core Parent [pom]
[INFO] TensorFlow Core Annotation Processor [jar]
[INFO] TensorFlow Core API Library [jar]
[INFO] TensorFlow Core API Library Platform [jar]
[INFO] TensorFlow Framework Library [jar]
[INFO]
[INFO] -------------------< org.tensorflow:tensorflow-java >-------------------
[INFO] Building TensorFlow Java Parent 0.1.0-SNAPSHOT [1/7]
[INFO] --------------------------------[ pom ]---------------------------------
...
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ tensorflow-framework ---
[INFO] Deleting /tmp/docker-share/tensorflow-java-karl/tensorflow-framework/target
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.1.0-SNAPSHOT:
[INFO]
[INFO] TensorFlow Java Parent ............................. SUCCESS [ 5.399 s]
[INFO] TensorFlow Tools Library ........................... SUCCESS [ 0.227 s]
[INFO] TensorFlow Core Parent ............................. SUCCESS [ 0.044 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [ 0.337 s]
[INFO] TensorFlow Core API Library ........................ SUCCESS [ 33.345 s]
[INFO] TensorFlow Core API Library Platform ............... SUCCESS [ 0.313 s]
[INFO] TensorFlow Framework Library ....................... SUCCESS [ 0.717 s]
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 41.954 s
[INFO] Finished at: 2020-08-14T19:25:24Z
[INFO] ------------------------------------------------------------------------
The clean worked. Next I ran `mvn install`, but got the same error:
root@478b80e86d3b:tensorflow-java-karl# mvn install -e
[INFO] Error stacktraces are turned on.
[INFO] Scanning for projects...
...
[INFO] --- maven-compiler-plugin:3.8.0:compile (default-compile) @ tensorflow-core-api ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 1678 source files to /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/target/classes
[INFO] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java: Some input files use unchecked or unsafe operations.
[INFO] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java: Recompile with -Xlint:unchecked for details.
[INFO] -------------------------------------------------------------
[ERROR] COMPILATION ERROR :
[INFO] -------------------------------------------------------------
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol
symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[135,55] cannot find symbol
symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[479,57] cannot find symbol
symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[485,57] cannot find symbol
symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[536,57] cannot find symbol
symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[82,65] cannot find symbol
symbol: variable internal_static_tensorflow_SaveableObject_descriptor
location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[88,65] cannot find symbol
symbol: variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[290,67] cannot find symbol
symbol: variable internal_static_tensorflow_SaveableObject_descriptor
location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[296,67] cannot find symbol
symbol: variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[329,67] cannot find symbol
symbol: variable internal_static_tensorflow_SaveableObject_descriptor
location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[INFO] 10 errors
[INFO] -------------------------------------------------------------
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.1.0-SNAPSHOT:
[INFO]
[INFO] TensorFlow Java Parent ............................. SUCCESS [ 1.561 s]
[INFO] TensorFlow Tools Library ........................... SUCCESS [02:59 min]
[INFO] TensorFlow Core Parent ............................. SUCCESS [ 0.132 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [ 4.513 s]
[INFO] TensorFlow Core API Library ........................ FAILURE [ 15:00 h]
[INFO] TensorFlow Core API Library Platform ............... SKIPPED
[INFO] TensorFlow Framework Library ....................... SKIPPED
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 15:03 h
[INFO] Finished at: 2020-08-15T10:30:05Z
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.0:compile (default-compile) on project tensorflow-core-api: Compilation failure: Compilation failure:
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR] location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[135,55] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
[ERROR] location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[479,57] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR] location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[485,57] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
[ERROR] location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[536,57] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR] location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[82,65] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR] location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[88,65] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
[ERROR] location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[290,67] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR] location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[296,67] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
[ERROR] location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[329,67] cannot find symbol
[ERROR] symbol: variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR] location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] -> [Help 1]
From a SIG JVM email (https://groups.google.com/a/tensorflow.org/d/msg/jvm/gGKO-hVS4Pc/LF4rLJOdAQAJ), it seems that cherry-picking the `save_model` changes and applying them to the mainline tensorflow-java might work? Thanks also for #101!
The `save_model` commit is karllessard@bdb0420.
That patch won't apply cleanly to the mainline tensorflow-java, so I'll try to make the changes by hand.
After applying the saved_model-related changes by hand to the mainline `tensorflow-java`, the build seems to work, and I will test this soon:
root@11ade3e6890f:tensorflow-java-savemodel# mvn install -e
...
[INFO] --- maven-install-plugin:2.4:install (default-install) @ tensorflow-framework ---
[INFO] Installing /tmp/docker-share/tensorflow-java-savemodel/tensorflow-framework/target/tensorflow-framework-0.2.0-SNAPSHOT.jar to /root/.m2/repository/org/tensorflow/tensorflow-framework/0.2.0-SNAPSHOT/tensorflow-framework-0.2.0-SNAPSHOT.jar
[INFO] Installing /tmp/docker-share/tensorflow-java-savemodel/tensorflow-framework/pom.xml to /root/.m2/repository/org/tensorflow/tensorflow-framework/0.2.0-SNAPSHOT/tensorflow-framework-0.2.0-SNAPSHOT.pom
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.2.0-SNAPSHOT:
[INFO]
[INFO] TensorFlow Java Parent ............................. SUCCESS [ 1.450 s]
[INFO] TensorFlow NdArray Library ......................... SUCCESS [02:26 min]
[INFO] TensorFlow Core Parent ............................. SUCCESS [ 0.055 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [ 3.259 s]
[INFO] TensorFlow Core API Library ........................ SUCCESS [07:41 min]
[INFO] TensorFlow Core API Library Platform ............... SUCCESS [ 0.873 s]
[INFO] TensorFlow Framework Library ....................... SUCCESS [ 38.209 s]
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 10:53 min
[INFO] Finished at: 2020-08-16T23:37:45Z
[INFO] ------------------------------------------------------------------------
My changes are almost identical to @karllessard's code in karllessard@bdb0420. In some places I made changes that do not alter what the code does but bring it closer to Karl's branch, which has the `save_model` code. If we want a branch and PR from my work, I am happy to clean this up further. A prior email suggested that others (e.g. Alexey Zinoviev) may also want to test model-saving code like this (https://groups.google.com/a/tensorflow.org/d/msg/jvm/gGKO-hVS4Pc/D0z5yu-gAQAJ).
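Most of my changes relate to NdArray namespacing. The diffs are below, starting with Ops.java. Note that this first diff compares my tree against the mainline tree, so its `-` and `+` sides are the reverse of the diffs that follow: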
tensorflow-java-savemodel# diff -u tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
--- tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java 2020-08-16 19:30:40.402099828 -0400
+++ ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java 2020-08-09 19:39:33.350722050 -0400
@@ -312,10 +312,10 @@
public final ImageOps image;
- public final DataOps data;
-
public final ShapeOps shape;
+ public final DataOps data;
+
public final IoOps io;
public final DtypesOps dtypes;
@@ -338,10 +338,10 @@
public final SignalOps signal;
- public final QuantizationOps quantization;
-
public final TrainOps train;
+ public final QuantizationOps quantization;
+
private final Scope scope;
private Ops(Scope scope) {
@@ -349,8 +349,8 @@
nn = new NnOps(scope);
summary = new SummaryOps(scope);
image = new ImageOps(scope);
- data = new DataOps(scope);
shape = new ShapeOps(scope);
+ data = new DataOps(scope);
io = new IoOps(scope);
dtypes = new DtypesOps(scope);
xla = new XlaOps(scope);
@@ -362,8 +362,8 @@
math = new MathOps(scope);
audio = new AudioOps(scope);
signal = new SignalOps(scope);
- quantization = new QuantizationOps(scope);
train = new TrainOps(scope);
+ quantization = new QuantizationOps(scope);
}
/**
Graph.java:
tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java 2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java 2020-08-16 19:19:15.860833163 -0400
@@ -44,7 +44,16 @@
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_WhileParams;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.NoOp;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.train.Restore;
+import org.tensorflow.op.train.Save;
import org.tensorflow.proto.framework.GraphDef;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.types.TString;
/**
@@ -67,6 +76,11 @@
this.nativeHandle = nativeHandle;
}
+ Graph(TF_Graph nativeHandle, SaverDef saverDef) {
+ this(nativeHandle);
+ this.saverDef = saverDef;
+ }
+
/**
* Release resources associated with the Graph.
*
@@ -287,6 +301,17 @@
return addGradients(null, new Output<?>[] {y}, x, null);
}
+ public SaverDef saverDef() {
+ if (saverDef == null) {
+ synchronized (this) {
+ if (saverDef == null) {
+ saverDef = addVariableSaver(this);
+ }
+ }
+ }
+ return saverDef;
+ }
+
/**
* Used to instantiate an abstract class which overrides the buildSubgraph method to build a
* conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
@@ -405,6 +430,7 @@
private final Object nativeHandleLock = new Object();
private TF_Graph nativeHandle;
private int refcount = 0;
+ private SaverDef saverDef;
private final List<Op> initializers = new ArrayList<>();
@@ -726,6 +752,53 @@
}
}
+ private static SaverDef addVariableSaver(Graph graph) {
+ Ops tf = Ops.create(graph).withSubScope("save");
+
+ List<String> varNames = new ArrayList<>();
+ List<Operand<?>> varOutputs = new ArrayList<>();
+ List<DataType<?>> varTypes = new ArrayList<>();
+
+ for (Iterator<Operation> iter = graph.operations(); iter.hasNext();) {
+ Operation op = iter.next();
+ if (op.type().equals("VariableV2")) {
+ varNames.add(op.name());
+ varOutputs.add(op.output(0));
+ varTypes.add(op.output(0).dataType());
+ }
+ }
+
+ // FIXME Need an easier way to initialize an NdArray from a list
+ String[] tmp = new String[varNames.size()];
+ Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
+ Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
+
+ Placeholder<TString> saveFilename = tf.placeholder(TString.DTYPE);
+ Save saveVariables = tf.train.save(
+ saveFilename,
+ varNamesTensor,
+ varSlices,
+ varOutputs
+ );
+ Restore restoreVariables = tf.train.restore(
+ saveFilename,
+ varNamesTensor,
+ varSlices,
+ varTypes
+ );
+ List<Op> restoreOps = new ArrayList<>(varOutputs.size());
+ for (int i = 0; i < varOutputs.size(); ++i) {
+ restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
+ }
+ NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
+
+ return SaverDef.newBuilder()
+ .setFilenameTensorName(saveFilename.op().name())
+ .setSaveTensorName(saveVariables.op().name())
+ .setRestoreOpName(restoreAll.op().name())
+ .build();
+ }
+
static {
TensorFlow.init();
}
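The `saverDef()` accessor in the Graph.java diff above creates the save/restore ops lazily using double-checked locking. Under the Java memory model that idiom is only guaranteed safe when the checked field is `volatile`, and the diff declares `private SaverDef saverDef;` without it. A minimal change would likely be to declare the field `volatile`. The plain-Java sketch below shows the textbook form. `LazySaverHolder` and `buildSaver()` are hypothetical stand-ins for `Graph` and `addVariableSaver(...)`, not TensorFlow APIs.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical stand-in for Graph: lazily builds a value exactly once. */
public class LazySaverHolder {
  // volatile makes the fully built value visible to threads that skip the lock
  private volatile String saverDef;
  // counts how many times the expensive builder actually ran
  final AtomicInteger builds = new AtomicInteger();

  public String saverDef() {
    String result = saverDef;        // single volatile read on the fast path
    if (result == null) {
      synchronized (this) {
        result = saverDef;           // re-check under the lock
        if (result == null) {
          result = buildSaver();
          saverDef = result;         // publish only after the value is complete
        }
      }
    }
    return result;
  }

  // Stand-in for addVariableSaver(graph), which adds save/restore ops to the graph
  private String buildSaver() {
    builds.incrementAndGet();
    return "hypothetical-saver-def";
  }

  public static void main(String[] args) throws InterruptedException {
    LazySaverHolder holder = new LazySaverHolder();
    Thread[] threads = new Thread[8];
    for (int i = 0; i < threads.length; i++) {
      threads[i] = new Thread(holder::saverDef);
      threads[i].start();
    }
    for (Thread t : threads) {
      t.join();
    }
    System.out.println("builds=" + holder.builds.get());
  }
}
```

The lock is what keeps the builder from running twice; in the real `Graph`, that means a second set of save/restore ops is never added under the `save` sub-scope.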
SavedModelBundle:
tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java 2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java 2020-08-16 19:18:29.544101268 -0400
@@ -20,6 +20,15 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
import com.google.protobuf.InvalidProtocolBufferException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
@@ -30,8 +39,16 @@
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.proto.framework.ConfigProto;
+import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
+import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
import org.tensorflow.proto.framework.RunOptions;
+import org.tensorflow.proto.framework.SavedModel;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.proto.framework.TensorShapeProto;
+import org.tensorflow.proto.framework.TensorShapeProto.Dim;
+import org.tensorflow.ndarray.Shape;
/**
* SavedModelBundle represents a model loaded from storage.
@@ -94,6 +111,78 @@
private RunOptions runOptions = null;
}
+ public static final class Exporter {
+
+ public Exporter withTags(String... tags) {
+ this.tags.addAll(Arrays.asList(tags));
+ return this;
+ }
+
+ public Exporter withSignature(Map<String, Operand<?>> inputs, Map<String, Operand<?>> outputs) {
+ return withSignature("serving_default", "tensorflow/serving/predict", inputs, outputs);
+ }
+
+ public Exporter withSignature(String signatureName, String methodName, Map<String, Operand<?>> inputs, Map<String, Operand<?>> outputs) {
+ SignatureDef.Builder signatureDefBuilder = SignatureDef.newBuilder();
+ for (Map.Entry<String, Operand<?>> inputEntry : inputs.entrySet()) {
+ signatureDefBuilder.putInputs(inputEntry.getKey(), toTensorInfo(inputEntry.getValue().asOutput()));
+ }
+ for (Map.Entry<String, Operand<?>> outputEntry : outputs.entrySet()) {
+ signatureDefBuilder.putOutputs(outputEntry.getKey(), toTensorInfo(outputEntry.getValue().asOutput()));
+ }
+ signatureDefBuilder.setMethodName(methodName);
+ metaGraphDefBuilder.putSignatureDef(signatureName, signatureDefBuilder.build());
+ return this;
+ }
+
+ public void export(Session session) throws IOException {
+ Graph graph = session.graph();
+ if (tags.isEmpty()) {
+ tags.add("serve");
+ }
+ // Important: it is imperative to retrieve the graphDef after the saverDef, as the latter might add new ops. FIXME Better way for handling this?
+ MetaGraphDef metaGraphDef = metaGraphDefBuilder
+ .setSaverDef(graph.saverDef())
+ .setGraphDef(graph.toGraphDef())
+ .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags))
+ .build();
+
+ // Make sure saved model directories exist
+ Path variableDir = Paths.get(exportDir, "variables");
+ variableDir.toFile().mkdirs();
+
+ // Save variable state, using the save ops that were added to the graph when the `SaverDef` was retrieved above
+ session.save(variableDir.resolve("variables").toString());
+
+ // Save graph
+ SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build();
+ try (OutputStream file = new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) {
+ savedModelDef.writeTo(file);
+ }
+ }
+
+ Exporter(String exportDir) {
+ this.exportDir = exportDir;
+ }
+
+ private final String exportDir;
+ private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
+ private final List<String> tags = new ArrayList<>();
+
+ private static TensorInfo toTensorInfo(Output<?> operand) {
+ Shape shape = operand.shape();
+ TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
+ for (int i = 0; i < shape.numDimensions(); ++i) {
+ tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
+ }
+ return TensorInfo.newBuilder()
+ .setDtype(DataType.forNumber(operand.dataType().nativeCode()))
+ .setTensorShape(tensorShapeBuilder)
+ .setName(operand.op().name() + ":" + operand.index())
+ .build();
+ }
+ }
+
/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
@@ -125,6 +214,10 @@
return new Loader(exportDir);
}
+ public static Exporter exporter(String exportDir) {
+ return new Exporter(exportDir);
+ }
+
/**
* Returns the <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
@@ -176,7 +269,7 @@
*/
private static SavedModelBundle fromHandle(
TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
- Graph graph = new Graph(graphHandle);
+ Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
Session session = new Session(graph, sessionHandle);
return new SavedModelBundle(graph, session, metaGraphDef);
}
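In `toTensorInfo(...)` above, an unknown dimension comes through as `shape.size(i) == -1`, which matches the `-1` convention of `TensorShapeProto.Dim`. An unknown *rank* is a different case. If `numDimensions()` reports -1 for it (as we understand the NdArray `Shape` API), the loop adds no dims, and the resulting proto is indistinguishable from a scalar unless `unknown_rank` is set. The plain-Java model below illustrates that distinction. `ShapeInfo` and `fromDims` are hypothetical names standing in for `TensorShapeProto` and the loop in the diff, and a `null` array models an unknown rank.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical plain-Java model of the TensorShapeProto fields used by toTensorInfo. */
public final class ShapeInfo {
  final boolean unknownRank;
  final List<Long> dims;  // -1 marks an unknown dimension, as in TensorShapeProto.Dim

  private ShapeInfo(boolean unknownRank, List<Long> dims) {
    this.unknownRank = unknownRank;
    this.dims = Collections.unmodifiableList(dims);
  }

  /** dimSizes == null models a shape whose rank is unknown. */
  static ShapeInfo fromDims(long[] dimSizes) {
    if (dimSizes == null) {
      // Without this flag, an empty dim list would read as a scalar
      return new ShapeInfo(true, new ArrayList<>());
    }
    List<Long> dims = new ArrayList<>();
    for (long size : dimSizes) {
      dims.add(size);  // unknown sizes (-1) are copied through unchanged
    }
    return new ShapeInfo(false, dims);
  }

  boolean isScalar() {
    return !unknownRank && dims.isEmpty();
  }

  @Override
  public String toString() {
    return unknownRank ? "<unknown rank>" : dims.toString();
  }

  public static void main(String[] args) {
    System.out.println(fromDims(new long[] {2, 3}));
    System.out.println(fromDims(new long[] {-1, 3}));
    System.out.println(fromDims(new long[] {}));
    System.out.println(fromDims(null));
  }
}
```

In the exporter itself this would presumably mean calling `setUnknownRank(true)` on the `TensorShapeProto.Builder` when the rank is unknown; that change has not been tested against the branch.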
Session:
tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java 2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java 2020-08-15 16:30:35.751805310 -0400
@@ -36,6 +36,8 @@
import java.util.ArrayList;
import java.util.List;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.types.TString;
import static org.tensorflow.Graph.resolveOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.*;
@@ -444,6 +446,14 @@
runner().addTarget(op.op()).run();
}
+ public void save(String prefix) {
+ SaverDef saverDef = graph.saverDef();
+ runner()
+ .addTarget(saverDef.getSaveTensorName())
+ .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
+ .run();
+ }
+
/**
* Output tensors and metadata obtained when executing a session.
*
@@ -463,6 +473,10 @@
public RunMetadata metadata;
}
+ Graph graph() {
+ return graph;
+ }
+
private final Graph graph;
private final Graph.Reference graphRef;
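`Session.save(prefix)` feeds the prefix string into the filename placeholder and runs the save op. With a single-shard V2 checkpoint, the prefix `variables/variables` that the exporter passes in yields `variables/variables.index` plus `variables/variables.data-00000-of-00001`, which are the files `SavedModelBundleTest` checks for. The helper below only derives those names from a prefix so the layout is easy to see. `checkpointFiles` is a hypothetical helper, and the `%05d` shard padding is our reading of TensorFlow's checkpoint naming, not something the diff defines.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper showing which files a V2 checkpoint prefix expands to. */
public final class CheckpointFiles {
  static List<Path> checkpointFiles(Path prefix, int numShards) {
    String base = prefix.getFileName().toString();
    List<Path> files = new ArrayList<>();
    // One index file per checkpoint
    files.add(prefix.resolveSibling(base + ".index"));
    // One data file per shard, with zero-padded shard number and shard count
    for (int shard = 0; shard < numShards; shard++) {
      files.add(prefix.resolveSibling(String.format("%s.data-%05d-of-%05d", base, shard, numShards)));
    }
    return files;
  }

  public static void main(String[] args) {
    // Same shape of prefix that Exporter.export passes to session.save(...)
    Path prefix = Paths.get("export-dir", "variables", "variables");
    for (Path file : checkpointFiles(prefix, 1)) {
      System.out.println(file);
    }
  }
}
```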
SavedModelBundleTest exercises the `exporter(...)` method to save a model:
tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java 2020-07-29 18:34:42.276476950 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java 2020-08-16 19:26:24.250842779 -0400
@@ -15,21 +15,39 @@
package org.tensorflow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
+import java.io.IOException;
import java.net.URISyntaxException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import java.nio.file.Paths;
+import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.tensorflow.exceptions.TensorFlowException;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunOptions;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.types.TFloat32;
/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
public class SavedModelBundleTest {
+ private static final float EPSILON = 1e-7f;
private static final String SAVED_MODEL_PATH;
+
static {
try {
SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString();
@@ -72,6 +90,73 @@
}
}
+ @Test
+ public void save() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+ float reducedSum;
+ FloatNdArray xValue = StdArrays.ndCopyOf(new float[][] { { 0, 1, 2 }, { 3, 4, 5 } });
+ Shape xyShape = Shape.of(2, 3L);
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xyShape));
+ Variable<TFloat32> y = tf.variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE));
+ ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
+ Init init = tf.init();
+
+ try (Session s = new Session(g)) {
+ s.run(init);
+ try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
+ Tensor<TFloat32> zTensor = s.runner()
+ .feed(x, xTensor)
+ .fetch(z)
+ .run()
+ .get(0).expect(TFloat32.DTYPE)) {
+ reducedSum = zTensor.data().getFloat();
+ }
+ SavedModelBundle.exporter(testFolder.toString())
+ .withTags("test")
+ .withSignature(Collections.singletonMap("input", x), Collections.singletonMap("reducedSum", z))
+ .export(s);
+ }
+ }
+ assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index"))));
+ assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001"))));
+ assertTrue(Files.exists(testFolder.resolve("saved_model.pb")));
+
+ // Reload the model just saved and validate its data
+ try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString(), "test")) {
+ assertNotNull(savedModel.metaGraphDef());
+ assertNotNull(savedModel.metaGraphDef().getSaverDef());
+ assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
+
+ SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
+ assertNotNull(signature);
+ assertEquals(1, signature.getInputsCount());
+ assertEquals(1, signature.getOutputsCount());
+
+ TensorInfo inputInfo = signature.getInputsMap().get("input");
+ assertNotNull(inputInfo);
+ assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
+ for (int i = 0; i < xyShape.numDimensions(); ++i) {
+ assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
+ }
+
+ TensorInfo outputInfo = signature.getOutputsMap().get("reducedSum");
+ assertNotNull(outputInfo);
+ assertEquals(0, outputInfo.getTensorShape().getDimCount());
+
+ // Run the saved model just loaded and make sure it returns the same result as before
+ try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
+ Tensor<TFloat32> zTensor = savedModel.session().runner()
+ .feed(inputInfo.getName(), xTensor)
+ .fetch(outputInfo.getName())
+ .run()
+ .get(0).expect(TFloat32.DTYPE)) {
+ assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON);
+ }
+ }
+ }
+
private static RunOptions sillyRunOptions() {
return RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
If anything looks incorrect, would a colleague kindly let me know? I plan to test soon. Thanks again!
Preliminary tests of model save/load pass. Test output shows orig_reduced_sum=17.154881 loaded_reduced_sum=17.154881, so the model reloaded from disk gives the same result as the original in-memory model. Good:
root@11ade3e6890f:tensorflow-java-savemodel# java -jar my.jar tf --verbose
DEBUG 1597658225600: Started ndarray TensorflowJavaTest.
DEBUG 1597658226181: matrix3d rank 3
DEBUG 1597658226186: Finished ndarray TensorflowJavaTest.
DEBUG 1597658226187: Started graph TensorflowJavaTest.
DEBUG 1597658230142: fetch_test fetched.data.getInt(0)=3 is 3? true
DEBUG 1597658230143: fetch_test fetched.data.getInt(1)=4 is 4? true
DEBUG 1597658230357: feed_test fetched.data=8,6,4,2
DEBUG 1597658230365: Started save_test saving to dir=tf-saved-model-export-test
DEBUG 1597658231576: Finished save_test
DEBUG 1597658231584: Started load_test
2020-08-17 09:57:11.595866: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /tmp/tf-saved-model-export-test8180187955358911902
2020-08-17 09:57:11.596308: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { test }
2020-08-17 09:57:11.596360: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:295] Reading SavedModel debug info (if present) from: /tmp/tf-saved-model-export-test8180187955358911902
2020-08-17 09:57:11.597165: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
2020-08-17 09:57:11.614106: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:364] SavedModel load for tags { test }; Status: success: OK. Took 22272 microseconds.
DEBUG 1597658231632: load_test xy_shape dimension_count=2 input_info dimension_count=2
DEBUG 1597658231640: load_test xy_shape dimension_sizes=(2,3) input_info dimension_sizes=(2,3)
DEBUG 1597658231643: load_test output_info dimension=0
DEBUG 1597658231659: load_test orig_reduced_sum=17.154881 loaded_reduced_sum=17.154881
DEBUG 1597658231660: Finished load_test
DEBUG 1597658231661: Finished graph TensorflowJavaTest.
The main run() method for testing is now:
def run(): Int = {
  var ret = 0
  debug("Started ndarray TensorflowJavaTest.")
  // run simple data buffers test as in https://github.com/tensorflow/java/tree/master/ndarray
  val matrix3d = org.tensorflow.ndarray.NdArrays.ofInts(org.tensorflow.ndarray.Shape.of(2, 3, 2))
  debug("matrix3d rank " + matrix3d.rank)
  debug("Finished ndarray TensorflowJavaTest.")
  debug("Started graph TensorflowJavaTest.")
  fetch_and_feed_tests()
  save_and_load_tests()
  debug("Finished graph TensorflowJavaTest.")
  return ret
}
save_and_load_tests is:
private def save_and_load_tests(): Unit = {
  val graph = new org.tensorflow.Graph() // defined in https://github.com/tensorflow/java/blob/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
  val session = new org.tensorflow.Session(graph)
  try {
    val tf = org.tensorflow.op.Ops.create(graph)
    val (test_folder, x_value, xy_shape, z_tensor, reduced_sum) = save_test(graph, session, tf)
    val loaded_model = load_test(session, test_folder, x_value, xy_shape, z_tensor, reduced_sum)
    // release the native memory held by the fetched tensor and the reloaded bundle
    z_tensor.close()
    loaded_model.close()
  } finally {
    // close the session before the graph it runs on
    session.close()
    graph.close()
  }
}
Model save is tested in save_test, which relies on @karllessard's SavedModelBundle.exporter(...) method, thank you!
import java.nio.file.{Files, Path}
import java.util.Collections

private def save_test(graph: org.tensorflow.Graph, session: org.tensorflow.Session, tf: org.tensorflow.op.Ops,
                      temp_dirname: String = "tf-saved-model-export-test"):
    (Path, org.tensorflow.ndarray.FloatNdArray, org.tensorflow.ndarray.Shape,
     org.tensorflow.Tensor[org.tensorflow.types.TFloat32], Float) = {
  // adapted from save() in https://github.com/karllessard/tensorflow-java/blob/bdb0420fd72264457a6d47e5e6c6f7a2e56b4271/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java#L116
  // use exporter() to save model.
  debug("Started save_test saving to dir=%s".format(temp_dirname))
  val test_folder: Path = Files.createTempDirectory(temp_dirname)
  //val x_value = org.tensorflow.tools.ndarray.StdArrays.ndCopyOf(new float[][] { { 0, 1, 2 }, { 3, 4, 5 } }) // FloatNdArray
  val x_value: org.tensorflow.ndarray.FloatNdArray =
    org.tensorflow.ndarray.StdArrays.ndCopyOf(Array(Array(0.0f, 1.0f, 2.0f), Array(3.0f, 4.0f, 5.0f)))
  val xy_shape = org.tensorflow.ndarray.Shape.of(2, 3L)
  val x = tf.placeholder(org.tensorflow.types.TFloat32.DTYPE, org.tensorflow.op.core.Placeholder.shape(xy_shape))
  val y = tf.variable(tf.random.randomUniform(tf.constant(xy_shape), org.tensorflow.types.TFloat32.DTYPE))
  val z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1))
  val init = tf.init
  session.run(init)
  val x_tensor = org.tensorflow.types.TFloat32.tensorOf(x_value)
  val z_tensor = session.runner
    .feed(x, x_tensor)
    .fetch(z)
    .run
    .get(0).expect(org.tensorflow.types.TFloat32.DTYPE)
  x_tensor.close() // the input tensor is no longer needed once the run has completed
  val reduced_sum: Float = z_tensor.data.getFloat()
  org.tensorflow.SavedModelBundle.exporter(test_folder.toString)
    .withTags("test")
    .withSignature(Collections.singletonMap("input", x), Collections.singletonMap("reducedSum", z))
    .export(session)
  debug("Finished save_test")
  (test_folder, x_value, xy_shape, z_tensor, reduced_sum)
}
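For reference, the graph built in save_test computes z = reduceSum(x + y, [0, 1]), i.e. the sum of every element of x + y, so the fetched value is a scalar (which is why the loaded output signature reports 0 dimensions). The plain-Java sketch below reproduces only that arithmetic, not TensorFlow itself; the y arrays are hypothetical stand-ins for the random variable. Since x sums to 15 and randomUniform draws y from [0, 1), the result should land in [15, 21), and the logged 17.154881 fits that range.

```java
// Plain-Java sketch of reduceSum(add(x, y), axes = {0, 1}) for 2-D arrays.
// Not TensorFlow code; the y values used in main() are hypothetical stand-ins for the random variable.
final class ReduceSumSketch {

    // Element-wise x + y followed by a sum over both axes, yielding a single scalar.
    static float reduceSumOfAdd(float[][] x, float[][] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("row counts differ");
        }
        float sum = 0f;
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != y[i].length) {
                throw new IllegalArgumentException("column counts differ in row " + i);
            }
            for (int j = 0; j < x[i].length; j++) {
                sum += x[i][j] + y[i][j];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        float[][] x = {{0f, 1f, 2f}, {3f, 4f, 5f}};                  // same x_value as save_test
        float[][] yZero = new float[2][3];                             // lower-bound case: y all zeros
        float[][] yHalf = {{0.5f, 0.5f, 0.5f}, {0.5f, 0.5f, 0.5f}};    // hypothetical y
        System.out.println(reduceSumOfAdd(x, yZero)); // 15.0
        System.out.println(reduceSumOfAdd(x, yHalf)); // 18.0
    }
}
```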
Model load is tested in load_test, which is heavily based on Karl's helpful unit tests:
private def load_test(session: org.tensorflow.Session,
                      test_folder: Path, x_value: org.tensorflow.ndarray.FloatNdArray,
                      xy_shape: org.tensorflow.ndarray.Shape,
                      orig_z_tensor: org.tensorflow.Tensor[org.tensorflow.types.TFloat32],
                      orig_reduced_sum: Float):
    org.tensorflow.SavedModelBundle = {
  debug("Started load_test")
  // like save_test, adapted from https://github.com/karllessard/tensorflow-java/blob/bdb0420fd72264457a6d47e5e6c6f7a2e56b4271/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java#L116
  val saved_model: org.tensorflow.SavedModelBundle = org.tensorflow.SavedModelBundle.load(test_folder.toString, "test")
  val signature = saved_model.metaGraphDef.getSignatureDefMap.get("serving_default") // SignatureDef
  val input_info = signature.getInputsMap.get("input") // TensorInfo
  debug("load_test xy_shape dimension_count=%d input_info dimension_count=%d".format(
    xy_shape.numDimensions, input_info.getTensorShape.getDimCount)) // dimension counts should match
  debug("load_test xy_shape dimension_sizes=(%s) input_info dimension_sizes=(%s)".format(
    (0 until xy_shape.numDimensions).map(xy_shape.size(_)).mkString(","),
    (0 until input_info.getTensorShape.getDimCount).map(input_info.getTensorShape.getDim(_).getSize).mkString(","))) // dimension sizes should match
  val output_info = signature.getOutputsMap.get("reducedSum") // TensorInfo
  debug("load_test output_info dimension=%d".format(output_info.getTensorShape.getDimCount))
  val x_tensor = org.tensorflow.types.TFloat32.tensorOf(x_value)
  val z_tensor = saved_model.session.runner
    .feed(input_info.getName, x_tensor)
    .fetch(output_info.getName)
    .run
    .get(0).expect(org.tensorflow.types.TFloat32.DTYPE)
  debug("load_test orig_reduced_sum=%f loaded_reduced_sum=%f".format(orig_reduced_sum, z_tensor.data.getFloat()))
  // free the native memory of the tensors created in this test
  x_tensor.close()
  z_tensor.close()
  debug("Finished load_test")
  saved_model
}
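load_test only logs the shapes and sums for a human to compare. A small sketch of turning those log lines into pass/fail checks, in the spirit of the assertEquals(..., EPSILON) calls in the Java test: the dimension sizes would be gathered from Shape.size(i) and TensorShapeProto.getDim(i).getSize() (as load_test already does) and compared as arrays, and the reduced sums compared within a tolerance. The class name, method names and the 1e-6f tolerance are hypothetical choices, not part of tensorflow/java.

```java
import java.util.Arrays;

// Hypothetical helper for turning load_test's log lines into pass/fail checks.
final class LoadTestChecks {

    // Hypothetical tolerance; the Java test defines its own EPSILON constant.
    static final float EPSILON = 1e-6f;

    // True when both shapes have the same rank and the same size in every dimension.
    // expectedDims would come from Shape.size(i), loadedDims from TensorShapeProto.getDim(i).getSize().
    static boolean sameDims(long[] expectedDims, long[] loadedDims) {
        return Arrays.equals(expectedDims, loadedDims);
    }

    // True when the two floats differ by at most eps.
    static boolean closeEnough(float expected, float actual, float eps) {
        return Math.abs(expected - actual) <= eps;
    }

    // Formats dimension sizes the way load_test logs them, e.g. "(2,3)".
    static String formatDims(long[] dims) {
        StringBuilder sb = new StringBuilder("(");
        for (int i = 0; i < dims.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(dims[i]);
        }
        return sb.append(')').toString();
    }

    public static void main(String[] args) {
        long[] xyDims = {2, 3};
        long[] inputInfoDims = {2, 3};
        // prints "(2,3) matches (2,3)? true"
        System.out.println(formatDims(xyDims) + " matches " + formatDims(inputInfoDims) + "? "
                + sameDims(xyDims, inputInfoDims));
        // prints "true": the original and reloaded sums from the log are identical
        System.out.println(closeEnough(17.154881f, 17.154881f, EPSILON));
    }
}
```

In the Scala harness, a failed check could then throw instead of only logging, so a mismatch between the saved and reloaded model would stop the test run.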
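The result is a SavedModel directory in /tmp (Files.createTempDirectory appends a random number to the tf-saved-model-export-test prefix). Would any colleagues kindly give their opinion -- does this look reasonable?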
root@11ade3e6890f:tensorflow-java-savemodel# ls -hal /tmp/tf-saved-model-export-test8180187955358911902
total 16K
drwx------ 3 root root 4.0K Aug 17 09:57 .
drwxrwxrwt 1 root root 4.0K Aug 17 09:57 ..
-rw-r--r-- 1 root root 1.3K Aug 17 09:57 saved_model.pb
drwxr-xr-x 2 root root 4.0K Aug 17 09:57 variables
root@11ade3e6890f:tensorflow-java-savemodel# ls -hal /tmp/tf-saved-model-export-test8180187955358911902/variables
total 16K
drwxr-xr-x 2 root root 4.0K Aug 17 09:57 .
drwx------ 3 root root 4.0K Aug 17 09:57 ..
-rw-r--r-- 1 root root 24 Aug 17 09:57 variables.data-00000-of-00001
-rw-r--r-- 1 root root 132 Aug 17 09:57 variables.index
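The listing matches the layout the Java test asserts on: saved_model.pb at the top level plus variables/variables.index and variables/variables.data-00000-of-00001. Below is a sketch of the same check written with plain java.nio, so it could also be pointed at the /tmp directory from the Scala harness. The class and method names are hypothetical, the data file name assumes a single shard ("-00000-of-00001") as in this export, and it checks for regular files rather than just Files.exists.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: reports which expected SavedModel files are missing from an export directory.
final class SavedModelLayout {

    // Relative paths the Java test checks for; the data file name assumes a single shard.
    static final String[][] EXPECTED = {
        {"saved_model.pb"},
        {"variables", "variables.index"},
        {"variables", "variables.data-00000-of-00001"},
    };

    // Returns the relative paths of expected files that are absent (empty list = layout looks complete).
    static List<String> missingFiles(Path exportDir) {
        List<String> missing = new ArrayList<>();
        for (String[] parts : EXPECTED) {
            Path relative = Paths.get(parts[0], Arrays.copyOfRange(parts, 1, parts.length));
            if (!Files.isRegularFile(exportDir.resolve(relative))) {
                missing.add(relative.toString());
            }
        }
        return missing;
    }

    // Creates a throwaway directory with empty placeholder files, only to exercise missingFiles().
    // The directory is left in the system temp folder.
    static Path createFakeExport(boolean withDataShard) {
        try {
            Path dir = Files.createTempDirectory("layout-check");
            Files.createDirectories(dir.resolve("variables"));
            Files.createFile(dir.resolve("saved_model.pb"));
            Files.createFile(dir.resolve(Paths.get("variables", "variables.index")));
            if (withDataShard) {
                Files.createFile(dir.resolve(Paths.get("variables", "variables.data-00000-of-00001")));
            }
            return dir;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Without the data shard, exactly one file is reported missing.
        System.out.println("missing: " + missingFiles(createFakeExport(false)));
        // With every placeholder present, the list is empty.
        System.out.println("missing: " + missingFiles(createFakeExport(true)));
    }
}
```

Against the real export in the listing above, missingFiles(Paths.get("/tmp/tf-saved-model-export-test8180187955358911902")) would be expected to return an empty list.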