Float16 Quantization
February 9, 2025
1. Overview
Float16 quantization stores a model's weights as 16-bit instead of 32-bit floats, roughly halving the size of the weight files. A float16 value has 1 sign bit, 5 exponent bits, and 10 mantissa bits (versus 8 exponent and 23 mantissa bits in float32), so it covers a range of about ±65504 with 3 to 4 significant decimal digits; for most inference workloads this precision loss is negligible.
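To make the precision cost concrete, here is a small standalone decoder for raw float16 bits (a minimal sketch; the fromFloat16 helper is ours, not part of the script below). It shows that π can only be stored in float16 as 3.140625:

// Decode a raw float16 bit pattern (a Uint16 value) into a JS number.
function fromFloat16(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exp = (h >> 10) & 0x1f; // 5 exponent bits, bias 15
  const frac = h & 0x3ff; // 10 mantissa bits
  if (exp === 0) return sign * frac * 2 ** -24; // subnormal
  if (exp === 31) return frac ? NaN : sign * Infinity; // NaN / Infinity
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);
}
console.log(fromFloat16(0x4248)); // 3.140625, the closest float16 to Math.PI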
2. Implementation
const fs = require("fs");
const tf = require("@tensorflow/tfjs-node");
async function quantizeModelFloat16(inputModelPath, outputModelPath) {
  console.log("🔹 Loading original model...");
  const model = await tf.loadGraphModel(`file://${inputModelPath}`);
  console.log("🔹 Extracting weightMap...");
  // executor.weightMap is an internal TF.js API, but it is the most direct
  // way to reach the loaded weight tensors of a GraphModel.
  const weightMap = model.executor.weightMap;
  if (!weightMap || Object.keys(weightMap).length === 0) {
    console.error("❌ No weights found in model. Check model.json.");
    return;
  }
console.log("🔹 Applying float16 quantization...");
const quantizedWeights = [];
let allWeightData = [];
for (const [name, tensors] of Object.entries(weightMap)) {
const tensor = tensors[0]; // 取第一项(通常权重只有一项)
const originalData = tensor.dataSync();
const float16Data = new Uint16Array(originalData.length);
// 进行 float16 量化
for (let i = 0; i < originalData.length; i++) {
float16Data[i] = toFloat16(originalData[i]);
}
quantizedWeights.push({
name,
shape: tensor.shape,
dtype: "float16", // 确保 dtype 是 TF.js 兼容的
offset: allWeightData.reduce((sum, buf) => sum + buf.length, 0),
length: float16Data.length,
});
allWeightData.push(Buffer.from(float16Data.buffer));
}
console.log("🔹 Saving quantized model...");
fs.mkdirSync(outputModelPath, { recursive: true });
// 保存 model.json
const quantizedModel = {
modelTopology: model.modelTopology,
weightSpecs: quantizedWeights,
weightData: [],
};
fs.writeFileSync(
`${outputModelPath}/model.json`,
JSON.stringify(quantizedModel, null, 2)
);
// 保存量化后的权重到 `group1-shard1of1.bin`
const weightBinPath = `${outputModelPath}/group1-shard1of1.bin`;
fs.writeFileSync(weightBinPath, Buffer.concat(allWeightData));
console.log(`✅ Quantized model saved at: ${outputModelPath}`);
}
// Float32 -> Float16 转换函数
function toFloat16(value) {
const floatView = new Float32Array(1);
const int32View = new Uint32Array(floatView.buffer);
floatView[0] = value;
const x = int32View[0];
const b = (x >> 16) & 0x8000; // sign bit
const e = ((x >> 23) & 0xff) - 127 + 15; // exponent
const m = x & 0x7fffff; // mantissa
if (e <= 0) return b; // underflow
if (e >= 31) return b | 0x7c00; // overflow
return b | (e << 10) | (m >> 13);
}
// Run the quantization
quantizeModelFloat16("./model.json", "float16Model");
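As a quick sanity check, the quantized model can be loaded back with the stock loader: tf.loadGraphModel reads the quantization field written into model.json and decodes the float16 shard to float32 on the fly. A minimal sketch (verifyQuantizedModel is just an illustrative name, and it assumes the script above has already written float16Model/):

const tf = require("@tensorflow/tfjs-node");
async function verifyQuantizedModel() {
  // The loader dequantizes the float16 shard back to float32 automatically.
  const model = await tf.loadGraphModel("file://float16Model/model.json");
  console.log("Inputs:", model.inputs.map((t) => `${t.name} ${t.shape}`));
}
verifyQuantizedModel();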