Background
This article is a reference example for the TensorFlow Serving (TFS) Java API and covers the use of its core APIs. The example has three parts:
- Dynamic model update: loads models dynamically while TFS is running.
- Getting model status: retrieves basic information (metadata) about a loaded model.
- Online model prediction: prediction, classification, and related operations; this article focuses on prediction.
Because a prediction request must reference the model's internal inputs and outputs (signature name, tensor keys, dtypes), first obtain the TF model's metadata through the TFS REST API, and only then construct the TFS RPC request object.
Using TFS
Getting Model Metadata
```shell
curl http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/metadata
```
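Notes:
- See the TFS REST API documentation for the full endpoint reference.
- The response describes the model's signatures; a sample is shown in the TF Model Structure section at the end of this article.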
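Since the request fields have to match the model's signatures, it can be convenient to fetch the same metadata from Java instead of curl. The following is a minimal sketch using the JDK 11+ java.net.http client; the class and method names are hypothetical, and localhost:8501 is only an assumed address for a TFS instance with the REST API enabled.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical helper class: fetches TFS model metadata over the REST API.
class ModelMetadataRest {

    // Builds http://host:port/v1/models/{name}[/versions/{version}]/metadata.
    // Pass version = null to query the latest loaded version.
    static String metadataUrl(String host, int port, String modelName, String version) {
        StringBuilder url = new StringBuilder()
                .append("http://").append(host).append(':').append(port)
                .append("/v1/models/").append(modelName);
        if (version != null) {
            url.append("/versions/").append(version);
        }
        return url.append("/metadata").toString();
    }

    // Sends a GET request and returns the raw JSON body (model_spec + signature_def).
    static String fetchMetadata(String host, int port, String modelName, String version)
            throws IOException, InterruptedException {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest
                .newBuilder(URI.create(metadataUrl(host, port, modelName, version)))
                .GET()
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("metadata request failed: HTTP " + response.statusCode());
        }
        return response.body();
    }

    public static void main(String[] args) throws Exception {
        // Assumed REST address; adjust to your deployment.
        System.out.println(fetchMetadata("localhost", 8501, "wdl_model", null));
    }
}
```

Parsing the returned JSON is left out to keep the sketch dependency-free; any JSON library can then read metadata.signature_def.signature_def.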
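The same metadata can be fetched over gRPC:

```java
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.Arrays;
import tensorflow.serving.GetModelMetadata;
import tensorflow.serving.Model;
import tensorflow.serving.PredictionServiceGrpc;

// host and port are assumed to be fields holding the TFS gRPC address.
public static void getModelStatus() {
    // 1. Set the host and port of the TFS gRPC endpoint
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        // 2. Build the PredictionService blocking stub
        PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                PredictionServiceGrpc.newBlockingStub(channel);
        // 3. Specify the model to query
        Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                .setName("wdl_model")
                .build();
        // 4. Build the metadata request; TFS currently supports only "signature_def" here
        GetModelMetadata.GetModelMetadataRequest modelMetadataRequest =
                GetModelMetadata.GetModelMetadataRequest.newBuilder()
                        .setModelSpec(modelSpec)
                        .addAllMetadataField(Arrays.asList("signature_def"))
                        .build();
        // 5. Fetch the metadata
        GetModelMetadata.GetModelMetadataResponse getModelMetadataResponse =
                predictionServiceBlockingStub.getModelMetadata(modelMetadataRequest);
        System.out.println(getModelMetadataResponse);
    } finally {
        // Release the channel even if the call fails
        channel.shutdownNow();
    }
}
```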
Notes:
- Model.ModelSpec.newBuilder() sets the name of the model to query.
- addAllMetadataField in GetModelMetadataRequest requests the signature_def field, the same field that appears in the metadata returned by the curl command.
Dynamic Model Update
```java
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import tensorflow.serving.ModelManagement;
import tensorflow.serving.ModelServerConfigOuterClass;
import tensorflow.serving.ModelServiceGrpc;

// host and port are assumed to be fields holding the TFS gRPC address.
public static void addNewModel() {
    // 1. Config for the model being added (new_model).
    //    model_platform replaces the deprecated model_type field (ModelType.TENSORFLOW).
    ModelServerConfigOuterClass.ModelConfig modelConfig1 = ModelServerConfigOuterClass.ModelConfig.newBuilder()
            .setBasePath("/models/new_model")
            .setName("new_model")
            .setModelPlatform("tensorflow")
            .build();
    // 2. Config for the model that is already served (wdl_model); it must be sent again
    ModelServerConfigOuterClass.ModelConfig modelConfig2 = ModelServerConfigOuterClass.ModelConfig.newBuilder()
            .setBasePath("/models/wdl_model")
            .setName("wdl_model")
            .setModelPlatform("tensorflow")
            .build();
    // 3. Put both configs into one ModelConfigList
    ModelServerConfigOuterClass.ModelConfigList modelConfigList = ModelServerConfigOuterClass.ModelConfigList.newBuilder()
            .addConfig(modelConfig1)
            .addConfig(modelConfig2)
            .build();
    // 4. Wrap the ModelConfigList in a ModelServerConfig
    ModelServerConfigOuterClass.ModelServerConfig modelServerConfig = ModelServerConfigOuterClass.ModelServerConfig.newBuilder()
            .setModelConfigList(modelConfigList)
            .build();
    // 5. Build the ReloadConfigRequest carrying the ModelServerConfig
    ModelManagement.ReloadConfigRequest reloadConfigRequest = ModelManagement.ReloadConfigRequest.newBuilder()
            .setConfig(modelServerConfig)
            .build();
    // 6. Send the request through the ModelService blocking stub
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        ModelServiceGrpc.ModelServiceBlockingStub modelServiceBlockingStub = ModelServiceGrpc.newBlockingStub(channel);
        ModelManagement.ReloadConfigResponse reloadConfigResponse =
                modelServiceBlockingStub.handleReloadConfigRequest(reloadConfigRequest);
        // The error message is empty when the reload succeeds
        System.out.println(reloadConfigResponse.getStatus().getErrorMessage());
    } finally {
        channel.shutdownNow();
    }
}
```
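Notes:
- A dynamic model update is a full reload: the ModelServerConfig in the request replaces the server's entire model list. To publish model B dynamically after model A is already being served, the request must contain the configuration of both A and B; a model left out of the list is unloaded.
- Again: every reload must send the full model list, never only the new model.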
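Because each reload replaces the whole list, one way to avoid accidentally dropping a model is to keep a client-side record of everything already sent and always build the ModelConfigList from that record. The sketch below is a hypothetical helper, not part of the TFS API; it only tracks names and base paths, and each entry would become one ModelConfig as in addNewModel.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical client-side registry: remembers every model sent to TFS so that
// each reload request can carry the full model list.
class ModelRegistry {
    // model name -> base path, kept in insertion order
    private final Map<String, String> models = new LinkedHashMap<>();

    // Adds (or replaces) one model and returns the FULL list to send in the next reload.
    Map<String, String> withModel(String name, String basePath) {
        models.put(name, basePath);
        return Collections.unmodifiableMap(new LinkedHashMap<>(models));
    }

    // Models that would be unloaded if only `requested` were sent as the new config.
    static Set<String> unloadedBy(Set<String> served, Set<String> requested) {
        Set<String> dropped = new TreeSet<>(served);
        dropped.removeAll(requested);
        return dropped;
    }
}
```

For example, after withModel("wdl_model", ...) and withModel("new_model", ...), the returned map holds both models, which matches the two ModelConfig entries in addNewModel; unloadedBy makes the cost of sending only the new model explicit.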
Online Model Prediction
```java
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.FloatList;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

// host and port are assumed to be fields holding the TFS gRPC address.
public static void doPredict() throws Exception {
    // 1. Build the features of one tf.train.Example (all values are placeholders)
    Map<String, Feature> featureMap = new HashMap<>();
    featureMap.put("match_type", feature(""));
    List<String> floatFeatures = Arrays.asList(
            "position",
            "brand_prefer_1d", "brand_prefer_1m", "brand_prefer_1w", "brand_prefer_2w",
            "browse_norm_score_1d", "browse_norm_score_1w", "browse_norm_score_2w",
            "buy_norm_score_1d", "buy_norm_score_1w", "buy_norm_score_2w",
            "cate1_prefer_1d", "cate1_prefer_2d", "cate1_prefer_1m", "cate1_prefer_1w", "cate1_prefer_2w",
            "cate2_prefer_1d", "cate2_prefer_1m", "cate2_prefer_1w", "cate2_prefer_2w",
            "cid_prefer_1d", "cid_prefer_1m", "cid_prefer_1w", "cid_prefer_2w",
            "user_buy_rate_1d", "user_buy_rate_2w",
            "user_click_rate_1d", "user_click_rate_1w");
    for (String name : floatFeatures) {
        featureMap.put(name, feature(0.0f));
    }
    Features features = Features.newBuilder().putAllFeature(featureMap).build();
    Example example = Example.newBuilder().setFeatures(features).build();

    // 2. Build the Predict request
    Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();

    // 3. Bind the model name and the signature ("predict") in ModelSpec
    Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
    modelSpecBuilder.setName("wdl_model");
    modelSpecBuilder.setSignatureName("predict");
    predictRequestBuilder.setModelSpec(modelSpecBuilder);

    // 4. Describe the input tensor: 1-D DT_STRING with batchSize elements
    int batchSize = 300;
    TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(batchSize).build();
    TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(dim).build();
    TensorProto.Builder tensor = TensorProto.newBuilder();
    tensor.setTensorShape(shapeProto);
    tensor.setDtype(DataType.DT_STRING);

    // 5. Fill the batch with serialized Examples (here the same Example repeated)
    for (int i = 0; i < batchSize; i++) {
        tensor.addStringVal(example.toByteString());
    }
    predictRequestBuilder.putInputs("examples", tensor.build());

    // 6. Build the PredictionService blocking stub
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                PredictionServiceGrpc.newBlockingStub(channel);
        // 7. Run the prediction
        Predict.PredictResponse predictResponse = predictionServiceBlockingStub.predict(predictRequestBuilder.build());
        // 8. Read the "probabilities" output: shape [batchSize, 2], flattened in row-major order
        List<Float> floatList = predictResponse
                .getOutputsOrThrow("probabilities")
                .getFloatValList();
        System.out.println(floatList.size());
    } finally {
        channel.shutdownNow();
    }
}

// Hypothetical helpers: the original uses feature(...) without showing it.
// This assumed version wraps a string in a BytesList and a float in a FloatList.
private static Feature feature(String value) {
    return Feature.newBuilder()
            .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(value)))
            .build();
}

private static Feature feature(float value) {
    return Feature.newBuilder()
            .setFloatList(FloatList.newBuilder().addValue(value))
            .build();
}
```
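Notes:
- The fields of a TFS RPC request must match the structure of the TF model: the signature name (predict), the input key (examples) with dtype DT_STRING, and the output key (probabilities) all come from the model's signature_def (see TF Model Structure below).
- TFS RPC calls can be made synchronously or asynchronously; the example above shows only the synchronous (blocking stub) call. For asynchronous calls, grpc-java also generates a future stub (PredictionServiceGrpc.newFutureStub) whose predict method returns a ListenableFuture.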
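The probabilities output has shape [-1, 2] in the predict signature, but getFloatValList() returns it as one flat list, so a batch of 300 examples yields 600 values. A minimal sketch (hypothetical helper name) that splits the flat list back into one row per example, assuming the usual row-major order of tensor values:

```java
import java.util.List;

// Hypothetical helper: reshapes the flat float_val of a [batch, numClasses] tensor.
class ProbabilityRows {
    static float[][] toRows(List<Float> flat, int numClasses) {
        if (numClasses <= 0 || flat.size() % numClasses != 0) {
            throw new IllegalArgumentException(
                    "size " + flat.size() + " is not a multiple of " + numClasses);
        }
        int batch = flat.size() / numClasses;
        float[][] rows = new float[batch][numClasses];
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < numClasses; j++) {
                // Row-major: element (i, j) sits at index i * numClasses + j
                rows[i][j] = flat.get(i * numClasses + j);
            }
        }
        return rows;
    }
}
```

With floatList from doPredict, ProbabilityRows.toRows(floatList, 2)[i][1] would be the probability of class id 1 for the i-th example.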
TF Model Structure
```json
{
  "model_spec": { "name": "wdl_model", "signature_name": "", "version": "4" },
  "metadata": {
    "signature_def": {
      "signature_def": {
        "predict": {
          "inputs": {
            "examples": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false },
              "name": "input_example_tensor:0"
            }
          },
          "outputs": {
            "logistic": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/logistic:0"
            },
            "class_ids": {
              "dtype": "DT_INT64",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/ExpandDims:0"
            },
            "probabilities": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/probabilities:0"
            },
            "classes": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/str_classes:0"
            },
            "logits": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" }], "unknown_rank": false },
              "name": "add:0"
            }
          },
          "method_name": "tensorflow/serving/predict"
        },
        "classification": {
          "inputs": {
            "inputs": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false },
              "name": "input_example_tensor:0"
            }
          },
          "outputs": {
            "classes": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" }], "unknown_rank": false },
              "name": "head/Tile:0"
            },
            "scores": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/probabilities:0"
            }
          },
          "method_name": "tensorflow/serving/classify"
        },
        "regression": {
          "inputs": {
            "inputs": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false },
              "name": "input_example_tensor:0"
            }
          },
          "outputs": {
            "outputs": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/logistic:0"
            }
          },
          "method_name": "tensorflow/serving/regress"
        },
        "serving_default": {
          "inputs": {
            "inputs": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false },
              "name": "input_example_tensor:0"
            }
          },
          "outputs": {
            "classes": {
              "dtype": "DT_STRING",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" }], "unknown_rank": false },
              "name": "head/Tile:0"
            },
            "scores": {
              "dtype": "DT_FLOAT",
              "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" }], "unknown_rank": false },
              "name": "head/predictions/probabilities:0"
            }
          },
          "method_name": "tensorflow/serving/classify"
        }
      }
    }
  }
}
```