diff --git a/entry/src/main/cpp/HIAIModelManager.cpp b/entry/src/main/cpp/HIAIModelManager.cpp index 9f7a742468955dcdea4a71e0b1e96b7d63845c7b..f34abaeab184371b1fc9fdf3f2bca758f45e4fd6 100644 --- a/entry/src/main/cpp/HIAIModelManager.cpp +++ b/entry/src/main/cpp/HIAIModelManager.cpp @@ -58,19 +58,26 @@ size_t GetDeviceID() { return deviceID; } -OH_NN_ReturnCode SetInputTensorData(std::vector &inputTensors, std::map &inputDataMap) { - size_t i = 0; - for (auto iter = inputDataMap.begin(); iter != inputDataMap.end(); iter++) { - void *data = OH_NNTensor_GetDataBuffer(inputTensors[i]); +OH_NN_ReturnCode SetInputTensorData(std::vector &inputTensors, std::vector> &inputDataMap) { + if (inputTensors.size() != inputDataMap.size()) { + OH_LOG_ERROR(LOG_APP, "Input tensors and data map size mismatch"); + return OH_NN_FAILED; + } + for (size_t irIndex = 0; irIndex < inputDataMap.size(); irIndex++) { + void *data = OH_NNTensor_GetDataBuffer(inputTensors[irIndex]); + if (data == nullptr) { + OH_LOG_ERROR(LOG_APP, "OH_NNTensor_GetDataBuffer failed, result is nullptr."); + return OH_NN_FAILED; + } size_t size = 0; - OH_NNTensor_GetSize(inputTensors[i], &size); - if (size != iter->second) { - OH_LOG_ERROR(LOG_APP, "OH_NNTensor_GetSize failed"); + OH_NNTensor_GetSize(inputTensors[irIndex], &size); + if (size != inputDataMap[irIndex].second) { + OH_LOG_ERROR(LOG_APP, "OH_NNTensor_GetSize failed %{public}zu %{public}zu %{public}zu", irIndex, size, inputDataMap[irIndex].second); return OH_NN_FAILED; } - memcpy(data, iter->first, iter->second); - i++; + memcpy(data, inputDataMap[irIndex].first, inputDataMap[irIndex].second); } + OH_LOG_INFO(LOG_APP, "SetInputTensorData success"); return OH_NN_SUCCESS; } @@ -152,7 +159,6 @@ OH_NN_ReturnCode HIAIModelManager::LoadModelFromBuffer(uint8_t *modelData, size_ OH_NNCompilation_Destroy(&compilation); return OH_NN_FAILED; } - OH_NNCompilation_Destroy(&compilation); OH_LOG_INFO(LOG_APP, "LoadModelFromBuffer success"); gettimeofday(&tEnd, nullptr); @@ -163,7 +169,11 @@ OH_NN_ReturnCode HIAIModelManager::LoadModelFromBuffer(uint8_t *modelData, size_ } OH_NN_ReturnCode HIAIModelManager::InitLabels(void *data, size_t size) { - inputDataMap_[data] = size; + if (data == nullptr) { + OH_LOG_INFO(LOG_APP, "Init labels fail, data is nullptr."); + return OH_NN_SUCCESS; + } + inputDataMap_.emplace_back(data, size); OH_LOG_INFO(LOG_APP, "Init labels success"); return OH_NN_SUCCESS; } @@ -255,8 +265,8 @@ OH_NN_ReturnCode HIAIModelManager::RunModel(double &avgRunSyncTime) { return OH_NN_SUCCESS; } -std::vector> HIAIModelManager::GetResult() { - std::vector> outputs; +std::vector> HIAIModelManager::GetResult() { + std::vector> outputs; for (auto tensor : outputTensors_) { void *tensorData = OH_NNTensor_GetDataBuffer(tensor); if (tensorData == nullptr) { @@ -270,23 +280,35 @@ std::vector> HIAIModelManager::GetResult() { break; } - float *outputResult = static_cast(tensorData); - int floatSize = 4; - std::vector output(size / floatSize, 0.0); - for (size_t i = 0; i < size / floatSize; ++i) { - output[i] = outputResult[i]; - } - outputs.push_back(output); + outputs.emplace_back(tensorData, size); } if (outputs.size() != outputTensors_.size()) { OH_LOG_ERROR(LOG_APP, "output size mismatch"); outputs.clear(); } - OH_LOG_INFO(LOG_APP, "GetResult success %{public}zu", outputTensors_.size()); + OH_LOG_INFO(LOG_APP, "GetResult success %{public}zu", outputTensors_.size()); return outputs; } +std::vector HIAIModelManager::GetDataType() { + std::vector dataTypes; + for (auto tensor : outputTensors_) { + OH_NN_DataType dataType; + OH_NN_ReturnCode ret = OH_NNTensorDesc_GetDataType(OH_NNTensor_GetTensorDesc(tensor), &dataType); + if (ret != OH_NN_SUCCESS) { + OH_LOG_ERROR(LOG_APP, "Failed to get tensor data type."); + break; + } + dataTypes.push_back(dataType); + } + if (dataTypes.size() != outputTensors_.size()) { + OH_LOG_ERROR(LOG_APP, "data type size mismatch"); + dataTypes.clear(); + } + return dataTypes; +} + OH_NN_ReturnCode HIAIModelManager::UnloadModel(double &avgUnLoadTime) { struct timeval tStart {}; struct timeval tEnd {}; @@ -296,7 +318,9 @@ OH_NN_ReturnCode HIAIModelManager::UnloadModel(double &avgUnLoadTime) { inputTensors_.clear(); DestroyTensors(outputTensors_); outputTensors_.clear(); - OH_NNExecutor_Destroy(&executor_); + if (executor_ != nullptr) { + OH_NNExecutor_Destroy(&executor_); + } inputDataMap_.clear(); gettimeofday(&tEnd, nullptr);