SavedModelBundle exporter(...) build: mvn install [ERROR] tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol symbol: variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor location: class org.tensorflow.proto.framework.StructProtos

tensorflow commented on May 21, 2024
SavedModelBundle exporter(...) build: mvn install
[ERROR] tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
  location: class org.tensorflow.proto.framework.StructProtos

karllessard commented on May 21, 2024

Hi @aday00,

I'm not sure why you are getting this error. It looks like your generated protos are out of sync somehow; maybe running mvn clean first will fix it? (Be aware that this will retrigger a full Bazel build, though.)

But the save_model branch in my personal repo was just a very experimental showcase of what could be done, and I was not really expecting to support it for real usage. We just synced up on this topic during the SIG session, and we now have a clear plan for how we want to design this before merging it into tensorflow/java. This should happen in the next month or so.

aday00 commented on May 21, 2024

Thank you kindly, @karllessard!

It appears the SIG document is https://docs.google.com/document/d/1RqSIKKFLE8kMSjUuZWsVS4JNCRHIrCYiuWcFZRNk2oU/edit# and mentions this:

Saved models and function graphs
  Shajan's changes have been pushed to this branch
  Interesting thread about this topic
  We should build an API that is ready to:
    Serve session-centric models (estimators and TF1.x)
    Serve function-centric models (TF2.x)
    Create function graphs, referenced as an attribute of an op (e.g. `tf.data.Dataset.map`)
    Create function graphs with signature (like `tf.function` does)
    Should we save session-centric models from Java? The C API only supports this “officially” for now
    Confirmation: let's model the API with functions while still loading/saving session-centric graphs

It appears @Shajan's changes are discussed in #89 and in commit d845856.

I don't see a unit test, so I'm not sure how to use this to save a model. Any tips would be greatly appreciated!

In the meantime, I'm happy to continue with the save_model branch and SavedModelBundle exporter(...). Thanks for this pre-alpha feature, and for all the time you've invested in stewarding this project! Hopefully I can contribute a bit by testing.

karllessard commented on May 21, 2024

Yes, the model-saving portion is not in this branch yet, but work will start soon, and the result might differ a bit from what you see in my save_model branch right now because we will opt for a more "function-centric" API.

I'll keep you posted! Please feel free to share any feedback on the draft version in my repo, as it can also apply to the next version.

aday00 commented on May 21, 2024

Thanks, @karllessard! I ran mvn clean as suggested:

root@478b80e86d3b:tensorflow-java-karl# mvn clean
[INFO] Scanning for projects...
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Build Order:
[INFO] 
[INFO] TensorFlow Java Parent                                             [pom]
[INFO] TensorFlow Tools Library                                           [jar]
[INFO] TensorFlow Core Parent                                             [pom]
[INFO] TensorFlow Core Annotation Processor                               [jar]
[INFO] TensorFlow Core API Library                                        [jar]
[INFO] TensorFlow Core API Library Platform                               [jar]
[INFO] TensorFlow Framework Library                                       [jar]
[INFO] 
[INFO] -------------------< org.tensorflow:tensorflow-java >-------------------
[INFO] Building TensorFlow Java Parent 0.1.0-SNAPSHOT                     [1/7]
[INFO] --------------------------------[ pom ]---------------------------------
...
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ tensorflow-framework ---
[INFO] Deleting /tmp/docker-share/tensorflow-java-karl/tensorflow-framework/target
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.1.0-SNAPSHOT:
[INFO] 
[INFO] TensorFlow Java Parent ............................. SUCCESS [  5.399 s]
[INFO] TensorFlow Tools Library ........................... SUCCESS [  0.227 s]
[INFO] TensorFlow Core Parent ............................. SUCCESS [  0.044 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [  0.337 s]
[INFO] TensorFlow Core API Library ........................ SUCCESS [ 33.345 s]
[INFO] TensorFlow Core API Library Platform ............... SUCCESS [  0.313 s]
[INFO] TensorFlow Framework Library ....................... SUCCESS [  0.717 s]
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  41.954 s
[INFO] Finished at: 2020-08-14T19:25:24Z
[INFO] ------------------------------------------------------------------------

The clean worked. Next I ran mvn install, but got the same error:

root@478b80e86d3b:tensorflow-java-karl# mvn install -e
[INFO] Error stacktraces are turned on.
[INFO] Scanning for projects...
...
[INFO] --- maven-compiler-plugin:3.8.0:compile (default-compile) @ tensorflow-core-api ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 1678 source files to /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/target/classes
[INFO] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java: Some input files use unchecked or unsafe operations.
[INFO] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java: Recompile with -Xlint:unchecked for details.
[INFO] -------------------------------------------------------------
[ERROR] COMPILATION ERROR : 
[INFO] -------------------------------------------------------------
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
  location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[135,55] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
  location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[479,57] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
  location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[485,57] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
  location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[536,57] cannot find symbol
  symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
  location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[82,65] cannot find symbol
  symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
  location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[88,65] cannot find symbol
  symbol:   variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
  location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[290,67] cannot find symbol
  symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
  location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[296,67] cannot find symbol
  symbol:   variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
  location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[329,67] cannot find symbol
  symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
  location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[INFO] 10 errors 
[INFO] -------------------------------------------------------------
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.1.0-SNAPSHOT:
[INFO] 
[INFO] TensorFlow Java Parent ............................. SUCCESS [  1.561 s]
[INFO] TensorFlow Tools Library ........................... SUCCESS [02:59 min]
[INFO] TensorFlow Core Parent ............................. SUCCESS [  0.132 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [  4.513 s]
[INFO] TensorFlow Core API Library ........................ FAILURE [  15:00 h]
[INFO] TensorFlow Core API Library Platform ............... SKIPPED
[INFO] TensorFlow Framework Library ....................... SKIPPED
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  15:03 h
[INFO] Finished at: 2020-08-15T10:30:05Z
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.0:compile (default-compile) on project tensorflow-core-api: Compilation failure: Compilation failure: 
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[129,55] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[135,55] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
[ERROR]   location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[479,57] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[485,57] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_fieldAccessorTable
[ERROR]   location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/BoundedTensorSpecProto.java:[536,57] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_BoundedTensorSpecProto_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.StructProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[82,65] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[88,65] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
[ERROR]   location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[290,67] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[296,67] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_SaveableObject_fieldAccessorTable
[ERROR]   location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] /tmp/docker-share/tensorflow-java-karl/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/SaveableObject.java:[329,67] cannot find symbol
[ERROR]   symbol:   variable internal_static_tensorflow_SaveableObject_descriptor
[ERROR]   location: class org.tensorflow.proto.framework.SavedObjectGraphProtos
[ERROR] -> [Help 1]

From a SIG JVM email (https://groups.google.com/a/tensorflow.org/d/msg/jvm/gGKO-hVS4Pc/LF4rLJOdAQAJ), it seems that cherry-picking the save_model changes onto mainline tensorflow-java might work. Thanks also for #101!

aday00 commented on May 21, 2024

The save_model commit is karllessard@bdb0420

That patch won't apply cleanly to mainline tensorflow-java, so I'll try to make the change by hand.

aday00 commented on May 21, 2024

After applying the saved_model-related changes by hand on top of mainline tensorflow-java, the build seems to work; I will test it soon:

root@11ade3e6890f:tensorflow-java-savemodel# mvn install -e
...
[INFO] --- maven-install-plugin:2.4:install (default-install) @ tensorflow-framework ---
[INFO] Installing /tmp/docker-share/tensorflow-java-savemodel/tensorflow-framework/target/tensorflow-framework-0.2.0-SNAPSHOT.jar to /root/.m2/repository/org/tensorflow/tensorflow-framework/0.2.0-SNAPSHOT/tensorflow-framework-0.2.0-SNAPSHOT.jar
[INFO] Installing /tmp/docker-share/tensorflow-java-savemodel/tensorflow-framework/pom.xml to /root/.m2/repository/org/tensorflow/tensorflow-framework/0.2.0-SNAPSHOT/tensorflow-framework-0.2.0-SNAPSHOT.pom
[INFO] ------------------------------------------------------------------------
[INFO] Reactor Summary for TensorFlow Java Parent 0.2.0-SNAPSHOT:
[INFO]
[INFO] TensorFlow Java Parent ............................. SUCCESS [  1.450 s]
[INFO] TensorFlow NdArray Library ......................... SUCCESS [02:26 min]
[INFO] TensorFlow Core Parent ............................. SUCCESS [  0.055 s]
[INFO] TensorFlow Core Annotation Processor ............... SUCCESS [  3.259 s]
[INFO] TensorFlow Core API Library ........................ SUCCESS [07:41 min]
[INFO] TensorFlow Core API Library Platform ............... SUCCESS [  0.873 s]
[INFO] TensorFlow Framework Library ....................... SUCCESS [ 38.209 s]
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  10:53 min
[INFO] Finished at: 2020-08-16T23:37:45Z
[INFO] ------------------------------------------------------------------------

My changes are almost identical to @karllessard's commit karllessard@bdb0420. In places I made changes that did not alter what the code does but made it more similar to Karl's save_model branch. If we want a branch and PR from my work, I am happy to clean this up further. An earlier email suggested that others (e.g. Alexey Zinoviev) may also want to test model-saving code like this: https://groups.google.com/a/tensorflow.org/d/msg/jvm/gGKO-hVS4Pc/D0z5yu-gAQAJ.

Most of my changes relate to NdArray namespacing. The diffs are below, starting with Ops.java:

tensorflow-java-savemodel# diff -u tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

--- tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java	2020-08-16 19:30:40.402099828 -0400
+++ ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java	2020-08-09 19:39:33.350722050 -0400
@@ -312,10 +312,10 @@
 
   public final ImageOps image;
 
-  public final DataOps data;
-
   public final ShapeOps shape;
 
+  public final DataOps data;
+
   public final IoOps io;
 
   public final DtypesOps dtypes;
@@ -338,10 +338,10 @@
 
   public final SignalOps signal;
 
-  public final QuantizationOps quantization;
-
   public final TrainOps train;
 
+  public final QuantizationOps quantization;
+
   private final Scope scope;
 
   private Ops(Scope scope) {
@@ -349,8 +349,8 @@
     nn = new NnOps(scope);
     summary = new SummaryOps(scope);
     image = new ImageOps(scope);
-    data = new DataOps(scope);
     shape = new ShapeOps(scope);
+    data = new DataOps(scope);
     io = new IoOps(scope);
     dtypes = new DtypesOps(scope);
     xla = new XlaOps(scope);
@@ -362,8 +362,8 @@
     math = new MathOps(scope);
     audio = new AudioOps(scope);
     signal = new SignalOps(scope);
-    quantization = new QuantizationOps(scope);
     train = new TrainOps(scope);
+    quantization = new QuantizationOps(scope);
   }
 
   /**

Graph.java:

tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java	2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java	2020-08-16 19:19:15.860833163 -0400
@@ -44,7 +44,16 @@
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.internal.c_api.TF_WhileParams;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.NoOp;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.train.Restore;
+import org.tensorflow.op.train.Save;
 import org.tensorflow.proto.framework.GraphDef;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.types.TString;
 
 
 /**
@@ -67,6 +76,11 @@
     this.nativeHandle = nativeHandle;
   }
 
+  Graph(TF_Graph nativeHandle, SaverDef saverDef) {
+    this(nativeHandle);
+    this.saverDef = saverDef;
+  }
+
   /**
    * Release resources associated with the Graph.
    *
@@ -287,6 +301,17 @@
     return addGradients(null, new Output<?>[] {y}, x, null);
   }
 
+  public SaverDef saverDef() {
+    if (saverDef == null) {
+      synchronized (this) {
+        if (saverDef == null) {
+          saverDef = addVariableSaver(this);
+        }
+      }
+    }
+    return saverDef;
+  }
+
   /**
    * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
    * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
@@ -405,6 +430,7 @@
   private final Object nativeHandleLock = new Object();
   private TF_Graph nativeHandle;
   private int refcount = 0;
+  private SaverDef saverDef;
 
   private final List<Op> initializers = new ArrayList<>();
 
@@ -726,6 +752,53 @@
     }
   }
 
+  private static SaverDef addVariableSaver(Graph graph) {
+    Ops tf = Ops.create(graph).withSubScope("save");
+
+    List<String> varNames = new ArrayList<>();
+    List<Operand<?>> varOutputs = new ArrayList<>();
+    List<DataType<?>> varTypes = new ArrayList<>();
+
+    for (Iterator<Operation> iter = graph.operations(); iter.hasNext();) {
+      Operation op = iter.next();
+      if (op.type().equals("VariableV2")) {
+        varNames.add(op.name());
+        varOutputs.add(op.output(0));
+        varTypes.add(op.output(0).dataType());
+      }
+    }
+
+    // FIXME Need an easier way to initialize an NdArray from a list
+    String[] tmp = new String[varNames.size()];
+    Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
+    Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
+
+    Placeholder<TString> saveFilename = tf.placeholder(TString.DTYPE);
+    Save saveVariables = tf.train.save(
+        saveFilename,
+        varNamesTensor,
+        varSlices,
+        varOutputs
+    );
+    Restore restoreVariables = tf.train.restore(
+        saveFilename,
+        varNamesTensor,
+        varSlices,
+        varTypes
+    );
+    List<Op> restoreOps = new ArrayList<>(varOutputs.size());
+    for (int i = 0; i < varOutputs.size(); ++i) {
+      restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
+    }
+    NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
+
+    return SaverDef.newBuilder()
+        .setFilenameTensorName(saveFilename.op().name())
+        .setSaveTensorName(saveVariables.op().name())
+        .setRestoreOpName(restoreAll.op().name())
+        .build();
+  }
+
   static {
     TensorFlow.init();
   }
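The lazy saverDef() accessor in the Graph.java diff uses double-checked locking on a plain field (`private SaverDef saverDef;`). Under the Java memory model that idiom is generally only guaranteed safe when the field is declared volatile (or every field of the published object is final); otherwise a second thread may see a non-null reference to a partially built object. A minimal sketch of the volatile form follows, using a hypothetical generic LazyValue holder (not part of TensorFlow Java) with a placeholder string standing in for the SaverDef:

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Demo: eight threads race to initialize the same lazy value.
class LazyValueDemo {
  public static void main(String[] args) throws InterruptedException {
    AtomicInteger builds = new AtomicInteger();
    LazyValue<String> saver = new LazyValue<>(() -> {
      builds.incrementAndGet(); // stands in for the expensive addVariableSaver(graph) call
      return "saver-def";       // placeholder value, not a real SaverDef
    });
    Thread[] threads = new Thread[8];
    for (int i = 0; i < threads.length; i++) {
      threads[i] = new Thread(saver::get);
      threads[i].start();
    }
    for (Thread t : threads) {
      t.join();
    }
    // The factory runs once no matter how many threads called get() concurrently.
    System.out.println(saver.get() + " built " + builds.get() + " time(s)"); // prints "saver-def built 1 time(s)"
  }
}

/** Hypothetical helper illustrating safe double-checked locking; not TensorFlow Java API. */
final class LazyValue<T> {
  private final Supplier<T> factory;
  private volatile T value; // volatile is what makes the unsynchronized first check safe

  LazyValue(Supplier<T> factory) {
    this.factory = factory;
  }

  T get() {
    T result = value;         // single volatile read on the fast path
    if (result == null) {
      synchronized (this) {
        result = value;       // re-check while holding the lock
        if (result == null) {
          result = factory.get();
          value = result;     // publish the fully constructed object
        }
      }
    }
    return result;
  }
}
```

Applied to the diff, the equivalent change would be declaring the field as `private volatile SaverDef saverDef;`. Copying the field into a local first keeps the common path to one volatile read.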

SavedModelBundle.java:

tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java	2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java	2020-08-16 19:18:29.544101268 -0400
@@ -20,6 +20,15 @@
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
 
 import com.google.protobuf.InvalidProtocolBufferException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.PointerPointer;
 import org.bytedeco.javacpp.PointerScope;
@@ -30,8 +39,16 @@
 import org.tensorflow.internal.c_api.TF_SessionOptions;
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.proto.framework.ConfigProto;
+import org.tensorflow.proto.framework.DataType;
 import org.tensorflow.proto.framework.MetaGraphDef;
+import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
 import org.tensorflow.proto.framework.RunOptions;
+import org.tensorflow.proto.framework.SavedModel;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.proto.framework.TensorShapeProto;
+import org.tensorflow.proto.framework.TensorShapeProto.Dim;
+import org.tensorflow.ndarray.Shape;
 
 /**
  * SavedModelBundle represents a model loaded from storage.
@@ -94,6 +111,78 @@
     private RunOptions runOptions = null;
   }
 
+  public static final class Exporter {
+
+    public Exporter withTags(String... tags) {
+      this.tags.addAll(Arrays.asList(tags));
+      return this;
+    }
+
+    public Exporter withSignature(Map<String, Operand<?>> inputs,  Map<String, Operand<?>> outputs) {
+      return withSignature("serving_default", "tensorflow/serving/predict", inputs, outputs);
+    }
+
+    public Exporter withSignature(String signatureName, String methodName, Map<String, Operand<?>> inputs,  Map<String, Operand<?>> outputs) {
+      SignatureDef.Builder signatureDefBuilder = SignatureDef.newBuilder();
+      for (Map.Entry<String, Operand<?>> inputEntry : inputs.entrySet()) {
+        signatureDefBuilder.putInputs(inputEntry.getKey(), toTensorInfo(inputEntry.getValue().asOutput()));
+      }
+      for (Map.Entry<String, Operand<?>> outputEntry : outputs.entrySet()) {
+        signatureDefBuilder.putOutputs(outputEntry.getKey(), toTensorInfo(outputEntry.getValue().asOutput()));
+      }
+      signatureDefBuilder.setMethodName(methodName);
+      metaGraphDefBuilder.putSignatureDef(signatureName, signatureDefBuilder.build());
+      return this;
+    }
+
+    public void export(Session session) throws IOException {
+      Graph graph = session.graph();
+      if (tags.isEmpty()) {
+        tags.add("serve");
+      }
+      // Important: it is imperative to retrieve the graphDef after the saverDef, as the former might add new ops. FIXME Better way for handling this?
+      MetaGraphDef metaGraphDef = metaGraphDefBuilder
+          .setSaverDef(graph.saverDef())
+          .setGraphDef(graph.toGraphDef())
+          .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags))
+          .build();
+
+      // Make sure saved model directories exist
+      Path variableDir = Paths.get(exportDir, "variables");
+      variableDir.toFile().mkdirs();
+
+      // Save variable state, this must be done before we retrieve the `SaverDef` from the graph
+      session.save(variableDir.resolve("variables").toString());
+
+      // Save graph
+      SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build();
+      try (OutputStream file = new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) {
+        savedModelDef.writeTo(file);
+      }
+    }
+
+    Exporter(String exportDir) {
+      this.exportDir = exportDir;
+    }
+
+    private final String exportDir;
+    private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
+    private final List<String> tags = new ArrayList<>();
+
+    private static TensorInfo toTensorInfo(Output<?> operand) {
+      Shape shape = operand.shape();
+      TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
+      for (int i = 0; i < shape.numDimensions(); ++i) {
+        tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
+      }
+      return TensorInfo.newBuilder()
+          .setDtype(DataType.forNumber(operand.dataType().nativeCode()))
+          .setTensorShape(tensorShapeBuilder)
+          .setName(operand.op().name() + ":" + operand.index())
+          .build();
+    }
+  }
+
   /**
    * Load a saved model from an export directory. The model that is being loaded should be created
    * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
@@ -125,6 +214,10 @@
     return new Loader(exportDir);
   }
 
+  public static Exporter exporter(String exportDir) {
+    return new Exporter(exportDir);
+  }
+
   /**
    * Returns the <a
    * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
@@ -176,7 +269,7 @@
    */
   private static SavedModelBundle fromHandle(
       TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
-    Graph graph = new Graph(graphHandle);
+    Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
     Session session = new Session(graph, sessionHandle);
     return new SavedModelBundle(graph, session, metaGraphDef);
   }
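toTensorInfo() above records each signature tensor under the name operand.op().name() + ":" + operand.index(), i.e. TensorFlow's "operation name, colon, output index" convention, so code consuming such a signature has to split the name back apart. A minimal sketch, where TensorRef, parse and format are hypothetical names for illustration rather than TensorFlow Java API, and a bare operation name is assumed to mean output 0:

```java
// Demo: round-trip names in the "<op name>:<output index>" form used by toTensorInfo().
class TensorRefDemo {
  public static void main(String[] args) {
    TensorRef input = TensorRef.parse("Placeholder:0"); // hypothetical op name
    System.out.println(input.opName + " -> output " + input.index);
    System.out.println(TensorRef.parse("save/Const"));  // bare name is read as output 0
  }
}

/** Hypothetical value type for a tensor name such as "scope/op:1"; not TensorFlow Java API. */
final class TensorRef {
  final String opName;
  final int index;

  private TensorRef(String opName, int index) {
    this.opName = opName;
    this.index = index;
  }

  /** Builds the same string toTensorInfo() writes: operation name, colon, output index. */
  static String format(String opName, int index) {
    return opName + ":" + index;
  }

  /**
   * Parses "op" or "op:idx". Splits on the last colon, since scopes use '/'.
   * Throws IllegalArgumentException (NumberFormatException is a subclass) on malformed input.
   */
  static TensorRef parse(String name) {
    int colon = name.lastIndexOf(':');
    if (colon < 0) {
      return new TensorRef(name, 0);
    }
    String op = name.substring(0, colon);
    int idx = Integer.parseInt(name.substring(colon + 1));
    if (op.isEmpty() || idx < 0) {
      throw new IllegalArgumentException("Malformed tensor name: " + name);
    }
    return new TensorRef(op, idx);
  }

  @Override
  public String toString() {
    return format(opName, index);
  }
}
```

Splitting on the last colon keeps scoped names such as save/Const intact, because scope components are separated by slashes rather than colons.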

Session.java:

tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java	2020-07-29 18:34:42.260476643 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java	2020-08-15 16:30:35.751805310 -0400
@@ -36,6 +36,8 @@
 
 import java.util.ArrayList;
 import java.util.List;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.types.TString;
 
 import static org.tensorflow.Graph.resolveOutputs;
 import static org.tensorflow.internal.c_api.global.tensorflow.*;
@@ -444,6 +446,14 @@
     runner().addTarget(op.op()).run();
   }
 
+  public void save(String prefix) {
+    SaverDef saverDef = graph.saverDef();
+    runner()
+        .addTarget(saverDef.getSaveTensorName())
+        .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
+        .run();
+  }
+
   /**
    * Output tensors and metadata obtained when executing a session.
    *
@@ -463,6 +473,10 @@
     public RunMetadata metadata;
   }
 
+  Graph graph() {
+    return graph;
+  }
+
   private final Graph graph;
   private final Graph.Reference graphRef;
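Exporter.export() creates <exportDir>/variables, hands the prefix <exportDir>/variables/variables to Session.save(), and writes saved_model.pb alongside; the test then checks for variables.index and variables.data-00000-of-00001. A stdlib-only sketch of that layout follows. SavedModelLayout is a hypothetical helper, and the shard naming (prefix.data-NNNNN-of-MMMMM) is assumed from the single-shard file names the test asserts. It uses Files.createDirectories, which throws an IOException on failure, whereas the diff's mkdirs() silently returns false:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

// Demo: print the paths an export into a fresh temporary directory would use.
class SavedModelLayoutDemo {
  public static void main(String[] args) throws IOException {
    Path exportDir = Files.createTempDirectory("layout-demo");
    Path prefix = SavedModelLayout.prepareVariablesPrefix(exportDir);
    for (Path p : SavedModelLayout.checkpointFiles(prefix, 1)) {
      System.out.println(exportDir.relativize(p));
    }
    System.out.println(exportDir.relativize(SavedModelLayout.graphFile(exportDir)));
  }
}

/** Hypothetical helper mirroring the layout written by Exporter.export(); not TensorFlow Java API. */
final class SavedModelLayout {
  private SavedModelLayout() {}

  /** Creates exportDir/variables (failing loudly) and returns the checkpoint prefix to save to. */
  static Path prepareVariablesPrefix(Path exportDir) throws IOException {
    Path variablesDir = exportDir.resolve("variables");
    Files.createDirectories(variablesDir); // no-op if it already exists as a directory
    return variablesDir.resolve("variables");
  }

  /** Files a checkpoint with this prefix is assumed to contain: one index plus N data shards. */
  static List<Path> checkpointFiles(Path prefix, int numShards) {
    List<Path> files = new ArrayList<>();
    String base = prefix.getFileName().toString();
    files.add(prefix.resolveSibling(base + ".index"));
    for (int shard = 0; shard < numShards; shard++) {
      files.add(prefix.resolveSibling(
          String.format("%s.data-%05d-of-%05d", base, shard, numShards)));
    }
    return files;
  }

  /** Location of the serialized SavedModel proto. */
  static Path graphFile(Path exportDir) {
    return exportDir.resolve("saved_model.pb");
  }
}
```

Failing at directory creation surfaces permission or path problems before the session tries to write the checkpoint, rather than as a less obvious error from the save op.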

SavedModelBundleTest exercises the exporter(...) method to save a model:

tensorflow-java-savemodel# diff -u {../tensorflow-java/,./}tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

--- ../tensorflow-java/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java	2020-07-29 18:34:42.276476950 -0400
+++ ./tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java	2020-08-16 19:26:24.250842779 -0400
@@ -15,21 +15,39 @@
 
 package org.tensorflow;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
+import java.io.IOException;
 import java.net.URISyntaxException;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.Collections;
 import org.junit.jupiter.api.Test;
 import org.tensorflow.exceptions.TensorFlowException;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.op.core.Variable;
 import org.tensorflow.proto.framework.ConfigProto;
 import org.tensorflow.proto.framework.RunOptions;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.types.TFloat32;
 
 /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
 public class SavedModelBundleTest {
 
+  private static final float EPSILON = 1e-7f;
   private static final String SAVED_MODEL_PATH;
+
   static {
     try {
       SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString();
@@ -72,6 +90,73 @@
     }
   }
 
+  @Test
+  public void save() throws IOException {
+    Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+    float reducedSum;
+    FloatNdArray xValue = StdArrays.ndCopyOf(new float[][] { { 0, 1, 2 }, { 3, 4, 5 } });
+    Shape xyShape = Shape.of(2, 3L);
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+      Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xyShape));
+      Variable<TFloat32> y = tf.variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE));
+      ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
+      Init init = tf.init();
+
+      try (Session s = new Session(g)) {
+        s.run(init);
+        try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
+            Tensor<TFloat32> zTensor = s.runner()
+                .feed(x, xTensor)
+                .fetch(z)
+                .run()
+                .get(0).expect(TFloat32.DTYPE)) {
+          reducedSum = zTensor.data().getFloat();
+        }
+        SavedModelBundle.exporter(testFolder.toString())
+            .withTags("test")
+            .withSignature(Collections.singletonMap("input", x), Collections.singletonMap("reducedSum", z))
+            .export(s);
+      }
+    }
+    assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index"))));
+    assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001"))));
+    assertTrue(Files.exists(testFolder.resolve("saved_model.pb")));
+
+    // Reload the model just saved and validate its data
+    try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString(), "test")) {
+      assertNotNull(savedModel.metaGraphDef());
+      assertNotNull(savedModel.metaGraphDef().getSaverDef());
+      assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
+
+      SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
+      assertNotNull(signature);
+      assertEquals(1, signature.getInputsCount());
+      assertEquals(1, signature.getOutputsCount());
+
+      TensorInfo inputInfo = signature.getInputsMap().get("input");
+      assertNotNull(inputInfo);
+      assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
+      for (int i = 0; i < xyShape.numDimensions(); ++i) {
+        assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
+      }
+
+      TensorInfo outputInfo = signature.getOutputsMap().get("reducedSum");
+      assertNotNull(outputInfo);
+      assertEquals(0, outputInfo.getTensorShape().getDimCount());
+
+      // Run the saved model just loaded and make sure it returns the same result as before
+      try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
+          Tensor<TFloat32> zTensor = savedModel.session().runner()
+              .feed(inputInfo.getName(), xTensor)
+              .fetch(outputInfo.getName())
+              .run()
+              .get(0).expect(TFloat32.DTYPE)) {
+        assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON);
+      }
+    }
+  }
+
   private static RunOptions sillyRunOptions() {
     return RunOptions.newBuilder()
         .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)

If anything looks incorrect, would a colleague kindly let me know? I plan to test soon. Thanks again!
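The three `Files.exists` assertions at the end of the save() test above describe the on-disk layout the exporter is expected to produce. Below is a minimal stdlib-only sketch of the same check. The class and method names (`SavedModelLayout`, `hasExpectedLayout`, `writePlaceholderExport`) are hypothetical. The single data-shard file name simply mirrors the test, which assumes the variables fit in one shard.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical helper mirroring the file checks at the end of the save() test. */
public class SavedModelLayout {

  /** Returns true if dir contains the files the save() test expects after export. */
  static boolean hasExpectedLayout(Path dir) {
    return Files.exists(dir.resolve("saved_model.pb"))
        && Files.exists(dir.resolve(Paths.get("variables", "variables.index")))
        // assumes the variables were written as a single shard, as in the test
        && Files.exists(dir.resolve(Paths.get("variables", "variables.data-00000-of-00001")));
  }

  /** Creates empty placeholder files with the expected names; only for exercising the check. */
  static Path writePlaceholderExport(Path dir) throws IOException {
    Files.createDirectories(dir.resolve("variables"));
    Files.createFile(dir.resolve("saved_model.pb"));
    Files.createFile(dir.resolve(Paths.get("variables", "variables.index")));
    Files.createFile(dir.resolve(Paths.get("variables", "variables.data-00000-of-00001")));
    return dir;
  }

  public static void main(String[] args) throws IOException {
    Path dir = Files.createTempDirectory("layout-check");
    System.out.println(hasExpectedLayout(dir)); // nothing written yet
    writePlaceholderExport(dir);
    System.out.println(hasExpectedLayout(dir)); // all three expected files now exist
  }
}
```

The sketch only checks names. A real export also has to load back correctly, which is what the second half of the test verifies.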

aday00 commented on May 21, 2024

Preliminary save/load tests pass. In the output below, orig_reduced_sum=17.154881 and loaded_reduced_sum=17.154881, so the reloaded model returns the same result as the model that was saved:

    root@11ade3e6890f:tensorflow-java-savemodel# java -jar my.jar tf --verbose
    DEBUG 1597658225600: Started ndarray TensorflowJavaTest.
    DEBUG 1597658226181: matrix3d rank 3
    DEBUG 1597658226186: Finished ndarray TensorflowJavaTest.
    DEBUG 1597658226187: Started graph TensorflowJavaTest.
    DEBUG 1597658230142: fetch_test fetched.data.getInt(0)=3 is 3? true
    DEBUG 1597658230143: fetch_test fetched.data.getInt(1)=4 is 4? true
    DEBUG 1597658230357: feed_test fetched.data=8,6,4,2
    DEBUG 1597658230365: Started save_test saving to dir=tf-saved-model-export-test
    DEBUG 1597658231576: Finished save_test
    DEBUG 1597658231584: Started load_test
    2020-08-17 09:57:11.595866: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /tmp/tf-saved-model-export-test8180187955358911902
    2020-08-17 09:57:11.596308: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { test }
    2020-08-17 09:57:11.596360: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:295] Reading SavedModel debug info (if present) from: /tmp/tf-saved-model-export-test8180187955358911902
    2020-08-17 09:57:11.597165: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
    2020-08-17 09:57:11.614106: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:364] SavedModel load for tags { test }; Status: success: OK. Took 22272 microseconds.
    DEBUG 1597658231632: load_test  xy_shape dimension_count=2  input_info dimension_count=2
    DEBUG 1597658231640: load_test  xy_shape dimension_sizes=(2,3)  input_info dimension_sizes=(2,3)
    DEBUG 1597658231643: load_test output_info dimension=0
    DEBUG 1597658231659: load_test orig_reduced_sum=17.154881 loaded_reduced_sum=17.154881
    DEBUG 1597658231660: Finished load_test
    DEBUG 1597658231661: Finished graph TensorflowJavaTest.
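The numbers above can be sanity-checked by hand. z is `reduceSum(x + y)` over both axes. The elements of x sum to 0+1+2+3+4+5 = 15, and each of the six elements of y is drawn from `randomUniform`'s [0, 1) range, so z must lie in [15, 21). The observed 17.154881 means y summed to about 2.154881. The two sums match exactly because y is not re-sampled on load. The init op ran once before export, the exporter wrote y's value into the variables checkpoint, and `SavedModelBundle.load` restores it ("Restoring SavedModel bundle." in the log).

Below is a plain-Java sketch of that arithmetic. `ReducedSumSketch` and its methods are hypothetical. `java.util.Random` stands in for TensorFlow's RNG, so the values of y will not match TensorFlow's.

```java
import java.util.Random;

/** Plain-Java sketch of z = reduceSum(x + y, axes {0, 1}) from save_test. */
public class ReducedSumSketch {

  /** Sums every element of x + y, i.e. a reduction over both axes. */
  static float reduceSumOfAdd(float[][] x, float[][] y) {
    float sum = 0f;
    for (int i = 0; i < x.length; i++) {
      for (int j = 0; j < x[i].length; j++) {
        sum += x[i][j] + y[i][j];
      }
    }
    return sum;
  }

  /** Stand-in for tf.random.randomUniform: values in [0, 1), but NOT TensorFlow's RNG. */
  static float[][] uniform(int rows, int cols, long seed) {
    Random rng = new Random(seed);
    float[][] y = new float[rows][cols];
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        y[i][j] = rng.nextFloat(); // nextFloat() returns a value in [0.0f, 1.0f)
      }
    }
    return y;
  }

  public static void main(String[] args) {
    float[][] x = {{0, 1, 2}, {3, 4, 5}}; // same x_value as the test; its elements sum to 15
    float[][] y = uniform(2, 3, 42L);     // sampled once, like running the init op once
    float original = reduceSumOfAdd(x, y);
    float reloaded = reduceSumOfAdd(x, y); // a reload restores the same y, nothing is re-sampled
    // 15 plus six values in [0, 1); the upper bound is inclusive only to allow for float rounding
    System.out.println(original >= 15f && original <= 21f);
    System.out.println(original == reloaded);
  }
}
```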

The main run() method for testing is now:

    def run(): Int = {
      debug("Started ndarray TensorflowJavaTest.")
      // run simple data buffers test as in https://github.com/tensorflow/java/tree/master/ndarray
      val matrix3d = org.tensorflow.ndarray.NdArrays.ofInts( org.tensorflow.ndarray.Shape.of(2, 3, 2) )
      debug("matrix3d rank " + matrix3d.rank)
      debug("Finished ndarray TensorflowJavaTest.")

      debug("Started graph TensorflowJavaTest.")
      fetch_and_feed_tests()
      save_and_load_tests()
      debug("Finished graph TensorflowJavaTest.")
      0 // exit status
    }

save_and_load_tests is:

    private def save_and_load_tests(): Unit = {
      val graph = new org.tensorflow.Graph() // defined in https://github.com/tensorflow/java/blob/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
      val session = new org.tensorflow.Session(graph)
      try {
        val tf = org.tensorflow.op.Ops.create(graph)

        val (test_folder, x_value, xy_shape, z_tensor, reduced_sum) = save_test(graph, session, tf)
        val loaded_model = load_test(session, test_folder, x_value, xy_shape, z_tensor, reduced_sum)
        // release the native memory held by the fetched tensor and the reloaded bundle
        z_tensor.close()
        loaded_model.close()
      } finally {
        // close the session before the graph it refers to
        session.close()
        graph.close()
      }
    }
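save_and_load_tests owns a Graph, a Session, and, through load_test, a SavedModelBundle. All of these hold native TensorFlow memory and should be closed when the test ends. Close order matters as well. As far as I know, TensorFlow Java's `Graph.close()` blocks until the sessions referring to the graph have been closed, so the session has to go first. The Java save() test in the diff gets this order from nested try-with-resources, which closes resources in reverse declaration order.

The sketch below shows only that ordering. `Tracked` is a hypothetical stand-in for Graph and Session, and no TensorFlow code is involved.

```java
import java.util.ArrayList;
import java.util.List;

public class CloseOrder {

  /** Hypothetical stand-in for Graph/Session: records its name when closed. */
  static final class Tracked implements AutoCloseable {
    private final String name;
    private final List<String> log;

    Tracked(String name, List<String> log) {
      this.name = name;
      this.log = log;
    }

    @Override
    public void close() {
      log.add(name);
    }
  }

  static List<String> runAndClose() {
    List<String> closed = new ArrayList<>();
    // Resources declared later are closed first, so "session" closes before "graph"
    try (Tracked graph = new Tracked("graph", closed);
        Tracked session = new Tracked("session", closed)) {
      // save_test / load_test would run here
    }
    return closed;
  }

  public static void main(String[] args) {
    System.out.println(runAndClose()); // [session, graph]
  }
}
```

In Scala, the same effect comes from a `try`/`finally` that closes the session and then the graph. On Scala 2.13+, `scala.util.Using` is another option.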

Model save is tested in save_test, which relies on @karllessard's SavedModelBundle exporter(...) method, thank you!

    private def save_test(graph: org.tensorflow.Graph, session: org.tensorflow.Session, tf: org.tensorflow.op.Ops,
                          temp_dirname: String = "tf-saved-model-export-test"):
                          (Path, org.tensorflow.ndarray.FloatNdArray, org.tensorflow.ndarray.Shape,
                           org.tensorflow.Tensor[org.tensorflow.types.TFloat32], Float) = {
      // adapted from save() in https://github.com/karllessard/tensorflow-java/blob/bdb0420fd72264457a6d47e5e6c6f7a2e56b4271/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java#L116
      // use exporter() to save model.
      // note: temp_dirname is only a prefix; createTempDirectory appends a random suffix to it
      debug("Started save_test saving to dir=%s".format(temp_dirname))
      val test_folder: Path = Files.createTempDirectory(temp_dirname)

      val x_value: org.tensorflow.ndarray.FloatNdArray =
        org.tensorflow.ndarray.StdArrays.ndCopyOf(Array(Array(0.0f, 1.0f, 2.0f), Array(3.0f, 4.0f, 5.0f)))
      val xy_shape = org.tensorflow.ndarray.Shape.of(2, 3L)

      val x = tf.placeholder(org.tensorflow.types.TFloat32.DTYPE, org.tensorflow.op.core.Placeholder.shape(xy_shape))
      val y = tf.variable(tf.random.randomUniform(tf.constant(xy_shape), org.tensorflow.types.TFloat32.DTYPE))
      val z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1))
      val init = tf.init

      session.run(init)
      val x_tensor = org.tensorflow.types.TFloat32.tensorOf(x_value)
      val z_tensor = session.runner
                            .feed(x, x_tensor)
                            .fetch(z)
                            .run
                            .get(0).expect(org.tensorflow.types.TFloat32.DTYPE)
      x_tensor.close() // the input tensor is no longer needed once z has been fetched
      val reduced_sum: Float = z_tensor.data.getFloat()
      org.tensorflow.SavedModelBundle.exporter(test_folder.toString)
                                     .withTags("test")
                                     .withSignature(Collections.singletonMap("input", x), Collections.singletonMap("reducedSum", z))
                                     .export(session)
      debug("Finished save_test")
      (test_folder, x_value, xy_shape, z_tensor, reduced_sum) // the caller owns z_tensor and must close it
    }
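The debug line in save_test prints `temp_dirname`, which is only the prefix passed to `Files.createTempDirectory`. The JDK appends a generated suffix, so the log shows `dir=tf-saved-model-export-test` while the loader reads from `/tmp/tf-saved-model-export-test8180187955358911902`. Below is a minimal stdlib sketch of that behaviour. `TempDirPrefix` and `newExportDir` are hypothetical names.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class TempDirPrefix {

  /** Creates a fresh export directory; prefix is only the start of its name. */
  static Path newExportDir(String prefix) throws IOException {
    Path dir = Files.createTempDirectory(prefix); // created under java.io.tmpdir, e.g. /tmp on Linux
    dir.toFile().deleteOnExit(); // best-effort cleanup: only succeeds if the directory is empty at exit
    return dir;
  }

  public static void main(String[] args) throws IOException {
    String prefix = "tf-saved-model-export-test";
    Path first = newExportDir(prefix);
    Path second = newExportDir(prefix);
    // Both names start with the prefix, but the generated suffixes make them distinct
    System.out.println(first.getFileName().toString().startsWith(prefix));
    System.out.println(!first.equals(second));
    System.out.println("log the real path, not the prefix: " + first);
  }
}
```

Logging the returned `Path` after creation, rather than the prefix, makes the save and load log lines point at the same directory.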

Model load is tested in load_test, which is heavily based on Karl's helpful unit tests:

    private def load_test(session: org.tensorflow.Session,
                          test_folder: Path, x_value: org.tensorflow.ndarray.FloatNdArray,
                          xy_shape: org.tensorflow.ndarray.Shape,
                          orig_z_tensor: org.tensorflow.Tensor[org.tensorflow.types.TFloat32],
                          orig_reduced_sum: Float):
                          org.tensorflow.SavedModelBundle = {
      // note: session and orig_z_tensor are currently unused; the loaded bundle has its own session
      debug("Started load_test")
      // like save_test, adapted from https://github.com/karllessard/tensorflow-java/blob/bdb0420fd72264457a6d47e5e6c6f7a2e56b4271/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java#L116
      val saved_model: org.tensorflow.SavedModelBundle = org.tensorflow.SavedModelBundle.load(test_folder.toString, "test")
      val signature = saved_model.metaGraphDef.getSignatureDefMap.get("serving_default") // SignatureDef
      val input_info = signature.getInputsMap.get("input") // TensorInfo

      debug("load_test  xy_shape dimension_count=%d  input_info dimension_count=%d".format(xy_shape.numDimensions, input_info.getTensorShape.getDimCount)) // dimension counts should match

      debug("load_test  xy_shape dimension_sizes=(%s)  input_info dimension_sizes=(%s)".format(
            (0 until xy_shape.numDimensions).map( xy_shape.size(_) ).mkString(","),
            (0 until input_info.getTensorShape.getDimCount).map( input_info.getTensorShape.getDim(_).getSize ).mkString(","))) // dimension sizes should match

      val output_info = signature.getOutputsMap.get("reducedSum") // TensorInfo
      debug("load_test output_info dimension=%d".format(output_info.getTensorShape.getDimCount))

      val x_tensor = org.tensorflow.types.TFloat32.tensorOf(x_value)
      val z_tensor = saved_model.session.runner
                                        .feed(input_info.getName, x_tensor)
                                        .fetch(output_info.getName)
                                        .run
                                        .get(0).expect(org.tensorflow.types.TFloat32.DTYPE)
      val loaded_reduced_sum: Float = z_tensor.data.getFloat()
      // free the native tensor memory now that the scalar has been read
      x_tensor.close()
      z_tensor.close()

      debug("load_test orig_reduced_sum=%f loaded_reduced_sum=%f".format(orig_reduced_sum, loaded_reduced_sum))

      debug("Finished load_test")
      saved_model // the caller owns the bundle and must close it
    }
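load_test logs the dimension counts and sizes but leaves the comparison to whoever reads the log. Below is a minimal sketch of turning that into a check that fails loudly. It uses plain Java, with `long[]` arrays standing in for the values of `xy_shape.size(i)` and `input_info.getTensorShape.getDim(i).getSize`. `ShapeCheck`, `format`, and `requireSameDims` are hypothetical names.

```java
import java.util.Arrays;
import java.util.stream.Collectors;

public class ShapeCheck {

  /** Formats dimension sizes the way load_test logs them, e.g. "(2,3)". */
  static String format(long[] dims) {
    return Arrays.stream(dims).mapToObj(Long::toString).collect(Collectors.joining(",", "(", ")"));
  }

  /** Throws if the expected and loaded dimension sizes differ in count or value. */
  static void requireSameDims(long[] expected, long[] loaded) {
    if (!Arrays.equals(expected, loaded)) {
      throw new IllegalStateException(
          "shape mismatch: expected " + format(expected) + " but loaded " + format(loaded));
    }
  }

  public static void main(String[] args) {
    long[] xyShape = {2, 3};   // sizes from xy_shape.size(i)
    long[] inputInfo = {2, 3}; // sizes from input_info.getTensorShape.getDim(i).getSize
    requireSameDims(xyShape, inputInfo);
    // reducedSum is a scalar, so its shape has no dimensions (the "dimension=0" log line)
    requireSameDims(new long[0], new long[0]);
    System.out.println("input shape " + format(inputInfo) + " OK");
  }
}
```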

The result is a SavedModel directory under /tmp. Would any colleagues kindly give their opinion: does this look reasonable?

    root@11ade3e6890f:tensorflow-java-savemodel# ls -hal /tmp/tf-saved-model-export-test8180187955358911902
    total 16K
    drwx------ 3 root root 4.0K Aug 17 09:57 .
    drwxrwxrwt 1 root root 4.0K Aug 17 09:57 ..
    -rw-r--r-- 1 root root 1.3K Aug 17 09:57 saved_model.pb
    drwxr-xr-x 2 root root 4.0K Aug 17 09:57 variables
    root@11ade3e6890f:tensorflow-java-savemodel# ls -hal /tmp/tf-saved-model-export-test8180187955358911902/variables
    total 16K
    drwxr-xr-x 2 root root 4.0K Aug 17 09:57 .
    drwx------ 3 root root 4.0K Aug 17 09:57 ..
    -rw-r--r-- 1 root root   24 Aug 17 09:57 variables.data-00000-of-00001
    -rw-r--r-- 1 root root  132 Aug 17 09:57 variables.index
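The file sizes in the listing can be sanity-checked with a little arithmetic. The model has a single variable, y, which is a 2x3 float32 tensor, so its raw contents take 6 * 4 = 24 bytes. That matches variables.data-00000-of-00001 exactly, and variables.index holds the checkpoint metadata. This reading is an inference from the sizes, not something confirmed in the thread. Below is a tiny sketch with a hypothetical `float32Bytes` helper.

```java
public class VariableBytes {

  /** Bytes needed for the raw contents of a float32 tensor with the given dimension sizes. */
  static long float32Bytes(long... dims) {
    long elements = 1; // a scalar (no dimensions) still holds one element
    for (long d : dims) {
      elements *= d;
    }
    return elements * Float.BYTES; // Float.BYTES == 4
  }

  public static void main(String[] args) {
    System.out.println(float32Bytes(2, 3)); // 24, the size of variables.data-00000-of-00001 above
  }
}
```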
