跳到主要内容

Tensorflow 加载模型

2025年02月13日
柏拉文
越努力,越幸运

一、认识


TFJS 中,一个完整的模型通常包含:

  • model.json: 模型结构、权重路径等元数据(数据流图和权重清单文件)

  • group1-shard\*of\*: 二进制权重文件的集合

基于 Service Worker 缓存 TensorFlow 模型文件和相关的 JavaScript 库,并控制缓存的版本, 具体思路如下:

  1. 注册 Service Worker: 通过 navigator.serviceWorker.register 来注册 Service Worker。并注意: service worker 最大的作用域是 worker 所在的位置(换句话说,如果脚本 sw.js 位于 /js/sw.js 中,默认情况下它只能控制 /js/ 下的 URL)。可以使用 Service-Worker-Allowed 标头指定 worker 的最大作用域列表。如果 navigator.serviceWorker.register(workerPath, { scope: "xx"}) 超过最大作用域, 会报错。

  2. 编写 Service Worker: 1. 缓存版本控制:通过 CACHE_VERSIONCACHE_NAME,为缓存命名时加入版本号,确保每次更新资源时能清理旧的缓存并使用新的缓存。每次版本更新时只保留当前版本的缓存,删除旧版本缓存。2. 安装阶段 (install):在 Service Worker 安装时,在 event.waitUntil 回调中, 使用 caches.open 打开缓存,并将指定的文件(model.jsongroup1-shard1of1.binTensorFlow JS 库)通过 cache.addAll 添加到缓存中。 使用 event.waitUntil 确保 install 事件处理完成之前,Service Worker 不会进入 activated 状态,防止安装阶段未完成就进行缓存管理。3. 激活阶段 (activate):在 Service Worker 激活时,在 event.waitUntil 回调中, 从 caches 中清理掉不再需要的旧版本缓存,确保只保留当前版本的缓存,避免浪费存储空间。4. 拦截请求 (fetch):在拦截 fetch 请求时,首先判断请求的 URL 是否匹配需要缓存的文件。如果是,通过 event.respondWith 劫持响应。在 event.respondWith 回调中, 我们通过 caches.match(event.request) 对网络请求里的每个资源与缓存里可获取的等效资源进行匹配,查看缓存中是否有相应的资源, 如果有缓存,则尝试从缓存中加载文件;如果缓存中没有,则通过网络请求文件并将其缓存。

  3. Web Worker 加载模型: 在 Web Worker 独立线程中, 通过 tf.loadGraphModel 来加载模型相关资源, 并添加 fetchFunc 参数来自定义资源请求逻辑。确保 TensorFlow.js 能够首先从缓存中加载模型文件,而不是默认通过网络请求。通过这种方式,可以减少重复的网络请求,提高性能和离线支持。Web Worker 用于在浏览器主线程之外执行 JavaScript,适用于计算密集型任务; Service Worker 主要用于缓存和拦截网络请求,适合离线支持、推送通知等场景。

通过 Service Worker 持久化缓存机制,可以保证在首次加载时从网络加载资源,并在后续加载时优先使用缓存资源,减少网络请求,提升性能和离线可用性,同时通过版本号控制确保更新后的资源被正确加载。

二、实现


2.1 注册 Service Worker

通过 navigator.serviceWorker.register 来注册 Service Worker。并注意: service worker 最大的作用域是 worker 所在的位置(换句话说,如果脚本 sw.js 位于 /js/sw.js 中,默认情况下它只能控制 /js/ 下的 URL)。可以使用 Service-Worker-Allowed 标头指定 worker 的最大作用域列表。如果 navigator.serviceWorker.register(workerPath, { scope: "xx"}) 超过最大作用域, 会报错。

export async function registerServiceWorker(workerPath) {
if (!('serviceWorker' in navigator)) {
console.warn('Service Worker is not supported in this browser.');
return;
}

try {
const registration = await navigator.serviceWorker.register(workerPath);
console.log('Service Worker registered with scope:', registration.scope);
} catch (error) {
console.error(`注册失败:${error}`);
}
}

2.2 编写 Service Worker

基于 Service Worker 缓存 TensorFlow 模型文件和相关的 JavaScript 库,并控制缓存的版本, 具体思路如下:

  1. 缓存版本控制:通过 CACHE_VERSIONCACHE_NAME,为缓存命名时加入版本号,确保每次更新资源时能清理旧的缓存并使用新的缓存。每次版本更新时只保留当前版本的缓存,删除旧版本缓存。

  2. 安装阶段 (install):在 Service Worker 安装时,在 event.waitUntil 回调中, 使用 caches.open 打开缓存,并将指定的文件(model.jsongroup1-shard1of1.binTensorFlow JS 库)通过 cache.addAll 添加到缓存中。 使用 event.waitUntil 确保 install 事件处理完成之前,Service Worker 不会进入 activated 状态,防止安装阶段未完成就进行缓存管理。

  3. 激活阶段 (activate):在 Service Worker 激活时,在 event.waitUntil 回调中, 从 caches 中清理掉不再需要的旧版本缓存,确保只保留当前版本的缓存,避免浪费存储空间。

  4. 拦截请求 (fetch):在拦截 fetch 请求时,首先判断请求的 URL 是否匹配需要缓存的文件。如果是,通过 event.respondWith 劫持响应。在 event.respondWith 回调中, 我们通过 caches.match(event.request) 对网络请求里的每个资源与缓存里可获取的等效资源进行匹配,查看缓存中是否有相应的资源, 如果有缓存,则尝试从缓存中加载文件;如果缓存中没有,则通过网络请求文件并将其缓存。

通过这样的缓存机制,可以保证在首次加载时从网络加载资源,并在后续加载时优先使用缓存资源,减少网络请求,提升性能和离线可用性,同时通过版本号控制确保更新后的资源被正确加载。

const CACHE_VERSION = 'v1'; // 设置缓存的版本号
const CACHE_NAME = `tensorflow-model-cache-${CACHE_VERSION}`; // 缓存名称包含版本号
const cacheUrlList = [
'./facePredictWorker/model.json',
'./facePredictWorker/group1-shard1of1.bin',
'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js'
];

// 安装阶段:缓存模型文件
self.addEventListener('install', event => {
event.waitUntil(
caches.open(CACHE_NAME).then(cache => {
return cache.addAll(cacheUrlList);
})
);
});

// 激活阶段:删除旧的缓存
self.addEventListener('activate', event => {
const cacheWhitelist = [CACHE_NAME]; // 只保留当前版本的缓存
event.waitUntil(
caches.keys().then(cacheNames => {
return Promise.all(
cacheNames
.filter(cacheName => !cacheWhitelist.includes(cacheName)) // 删除不在白名单中的缓存
.map(cacheName => caches.delete(cacheName))
);
})
);
});

function checkIsUseCache(requestUrl) {
const newCacheUrlList = cacheUrlList.map(url => {
if (url.startsWith('./')) {
url = url.replace('./', '');
}
return url;
});

return newCacheUrlList.some(url => requestUrl.includes(url));
}

// 拦截网络请求,使用缓存加载模型
self.addEventListener('fetch', event => {
const requestUrl = event?.request?.url || '';
const isUseCache = checkIsUseCache(requestUrl);

if (isUseCache) {
event.respondWith(
caches.match(event.request).then(cachedResponse => {
// 如果缓存中有请求的文件,直接返回缓存
return cachedResponse || fetch(event.request);
})
);
}
});

2.3 Web Worker 加载模型

Web Worker 独立线程中, 通过 tf.loadGraphModel 来加载模型相关资源, 并添加 fetchFunc 参数来自定义资源请求逻辑。确保 TensorFlow.js 能够首先从缓存中加载模型文件,而不是默认通过网络请求。通过这种方式,可以减少重复的网络请求,提高性能和离线支持。

Web Worker 用于在浏览器主线程之外执行 JavaScript,适用于计算密集型任务; Service Worker 主要用于缓存和拦截网络请求,适合离线支持、推送通知等场景。

model = await tf.loadGraphModel(modelUrl, {
fetchFunc: (url, options) => {
return caches.match(url).then(cachedResponse => {
return cachedResponse || fetch(url, options); // 如果缓存中有响应,则返回缓存的内容,否则继续发起请求
});
}
});