Tensorflow 加载模型
一、认识
在 TFJS
中,一个完整的模型通常包含:
-
model.json
: 模型结构、权重路径等元数据(数据流图和权重清单文件) -
group1-shard\*of\*
: 二进制权重文件的集合
基于 Service Worker
缓存 TensorFlow
模型文件和相关的 JavaScript
库,并控制缓存的版本, 具体思路如下:
-
注册
Service Worker
: 通过navigator.serviceWorker.register
来注册Service Worker
。并注意:service worker
最大的作用域是worker
所在的位置(换句话说,如果脚本sw.js
位于/js/sw.js
中,默认情况下它只能控制/js/
下的URL
)。可以使用Service-Worker-Allowed
标头指定worker
的最大作用域列表。如果navigator.serviceWorker.register(workerPath, { scope: "xx"})
超过最大作用域, 会报错。 -
编写
Service Worker
: 1. 缓存版本控制:通过CACHE_VERSION
和CACHE_NAME
,为缓存命名时加入版本号,确保每次更新资源时能清理旧的缓存并使用新的缓存。每次版本更新时只保留当前版本的缓存,删除旧版本缓存。2. 安装阶段 (install
):在Service Worker
安装时,在event.waitUntil
回调中, 使用caches.open
打开缓存,并将指定的文件(model.json
、group1-shard1of1.bin
和TensorFlow JS
库)通过cache.addAll
添加到缓存中。 使用event.waitUntil
确保install
事件处理完成之前,Service Worker
不会进入activated
状态,防止安装阶段未完成就进行缓存管理。3. 激活阶段 (activate
):在Service Worker
激活时,在event.waitUntil
回调中, 从caches
中清理掉不再需要的旧版本缓存,确保只保留当前版本的缓存,避免浪费存储空间。4. 拦截请求 (fetch
):在拦截fetch
请求时,首先判断请求的URL
是否匹配需要缓存的文件。如果是,通过event.respondWith
劫持响应。在event.respondWith
回调中, 我们通过caches.match(event.request)
对网络请求里的每个资源与缓存里可获取的等效资源进行匹配,查看缓存中是否有相应的资源, 如果有缓存,则尝试从缓存中加载文件;如果缓存中没有,则通过网络请求文件并将其缓存。 -
Web Worker
加载模型: 在Web Worker
独立线程中, 通过tf.loadGraphModel
来加载模型相关资源, 并添加fetchFunc
参数来自定义资源请求逻辑。确保TensorFlow.js
能够首先从缓存中加载模型文件,而不是默认通过网络请求。通过这种方式,可以减少重复的网络请求,提高性能和离线支持。Web Worker
用于在浏览器主线程之外执行JavaScript
,适用于计算密集型任务;Service Worker
主要用于缓存和拦截网络请求,适合离线支持、推送通知等场景。
通过 Service Worker
持久化缓存机制,可以保证在首次加载时从网络加载资源,并在后续加载时优先使用缓存资源,减少网络请求,提升性能和离线可用性,同时通过版本号控制确保更新后的资源被正确加载。
二、实现
2.1 注册 Service Worker
通过 navigator.serviceWorker.register
来注册 Service Worker
。并注意: service worker
最大的作用域是 worker
所在的位置(换句话说,如果脚本 sw.js
位于 /js/sw.js
中,默认情况下它只能控制 /js/
下的 URL
)。可以使用 Service-Worker-Allowed
标头指定 worker
的最大作用域列表。如果 navigator.serviceWorker.register(workerPath, { scope: "xx"})
超过最大作用域, 会报错。
export async function registerServiceWorker(workerPath) {
if (!('serviceWorker' in navigator)) {
console.warn('Service Worker is not supported in this browser.');
return;
}
try {
const registration = await navigator.serviceWorker.register(workerPath);
console.log('Service Worker registered with scope:', registration.scope);
} catch (error) {
console.error(`注册失败:${error}`);
}
}
2.2 编写 Service Worker
基于 Service Worker
缓存 TensorFlow
模型文件和相关的 JavaScript
库,并控制缓存的版本, 具体思路如下:
-
缓存版本控制:通过
CACHE_VERSION
和CACHE_NAME
,为缓存命名时加入版本号,确保每次更新资源时能清理旧的缓存并使用新的缓存。每次版本更新时只保留当前版本的缓存,删除旧版本缓存。 -
安装阶段 (
install
):在Service Worker
安装时,在event.waitUntil
回调中, 使用caches.open
打开缓存,并将指定的文件(model.json
、group1-shard1of1.bin
和TensorFlow JS
库)通过cache.addAll
添加到缓存中。 使用event.waitUntil
确保install
事件处理完成之前,Service Worker
不会进入activated
状态,防止安装阶段未完成就进行缓存管理。 -
激活阶段 (
activate
):在Service Worker
激活时,在event.waitUntil
回调中, 从caches
中清理掉不再需要的旧版本缓存,确保只保留当前版本的缓存,避免浪费存储空间。 -
拦截请求 (
fetch
):在拦截fetch
请求时,首先判断请求的URL
是否匹配需要缓存的文件。如果是,通过event.respondWith
劫持响应。在event.respondWith
回调中, 我们通过caches.match(event.request)
对网络请求里的每个资源与缓存里可获取的等效资源进行匹配,查看缓存中是否有相应的资源, 如果有缓存,则尝试从缓存中加载文件;如果缓存中没有,则通过网络请求文件并将其缓存。
通过这样的缓存机制,可以保证在首次加载时从网络加载资源,并在后续加载时优先使用缓存资源,减少网络请求,提升性能和离线可用性,同时通过版本号控制确保更新后的资源被正确加载。
const CACHE_VERSION = 'v1'; // 设置缓存的版本号
const CACHE_NAME = `tensorflow-model-cache-${CACHE_VERSION}`; // 缓存名称包含版本号
const cacheUrlList = [
'./facePredictWorker/model.json',
'./facePredictWorker/group1-shard1of1.bin',
'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js'
];
// 安装阶段:缓存模型文件
self.addEventListener('install', event => {
event.waitUntil(
caches.open(CACHE_NAME).then(cache => {
return cache.addAll(cacheUrlList);
})
);
});
// 激活阶段:删除旧的缓存
self.addEventListener('activate', event => {
const cacheWhitelist = [CACHE_NAME]; // 只保留当前版本的缓存
event.waitUntil(
caches.keys().then(cacheNames => {
return Promise.all(
cacheNames
.filter(cacheName => !cacheWhitelist.includes(cacheName)) // 删除不在白名单中的缓存
.map(cacheName => caches.delete(cacheName))
);
})
);
});
function checkIsUseCache(requestUrl) {
const newCacheUrlList = cacheUrlList.map(url => {
if (url.startsWith('./')) {
url = url.replace('./', '');
}
return url;
});
return newCacheUrlList.some(url => requestUrl.includes(url));
}
// 拦截网络请求,使用缓存加载模型
self.addEventListener('fetch', event => {
const requestUrl = event?.request?.url || '';
const isUseCache = checkIsUseCache(requestUrl);
if (isUseCache) {
event.respondWith(
caches.match(event.request).then(cachedResponse => {
// 如果缓存中有请求的文件,直接返回缓存
return cachedResponse || fetch(event.request);
})
);
}
});
2.3 Web Worker 加载模型
在 Web Worker
独立线程中, 通过 tf.loadGraphModel
来加载模型相关资源, 并添加 fetchFunc
参数来自定义资源请求逻辑。确保 TensorFlow.js
能够首先从缓存中加载模型文件,而不是默认通过网络请求。通过这种方式,可以减少重复的网络请求,提高性能和离线支持。
Web Worker
用于在浏览器主线程之外执行 JavaScript
,适用于计算密集型任务; Service Worker
主要用于缓存和拦截网络请求,适合离线支持、推送通知等场景。
model = await tf.loadGraphModel(modelUrl, {
fetchFunc: (url, options) => {
return caches.match(url).then(cachedResponse => {
return cachedResponse || fetch(url, options); // 如果缓存中有响应,则返回缓存的内容,否则继续发起请求
});
}
});