// OnDeviceService: uses transformers.js (Xenova's library, now published as @huggingface/transformers)
// to run a small causal LM in the browser. The library is imported as an ES module from the jsDelivr CDN.
import {pipeline} from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
/**
* On-device LLM inference service using transformers.js.
*/
export class OnDeviceService {
constructor({modelName = '', quantization = 'fp32'} = {}) {
this.modelName = modelName;
this.modelQuantization = quantization;
this._ready = false;
this._model = null;
}
/**
* Load the model into memory so it is ready for inference.
* Downloads the model if it is not already cached and caches it for future use.
*
* @param {Function} [progressCb] - Optional callback invoked with progress events while the model downloads and initializes
* @returns {Promise<void>}
*/
async load(progressCb) {
console.log(`⬇️ Downloading model '${this.modelName}'...`);
// Provide a default progress callback if none is given
const defaultProgressCb = (progress) => {
if (progress && typeof progress === 'object') {
if (progress.status) {
console.log(`[Model Loading] ${progress.status}`);
}
if (progress.loaded && progress.total) {
const percent = ((progress.loaded / progress.total) * 100).toFixed(1);
console.log(`[Model Loading] ${percent}% (${progress.loaded}/${progress.total} bytes)`);
}
} else {
console.log(`[Model Loading] Progress:`, progress);
}
};
this._model = await pipeline('text-generation', this.modelName, {
progress_callback: progressCb || defaultProgressCb,
device: 'webgpu', // run on WebGPU if available
dtype: this.modelQuantization, // set model quantization
});
console.log(`✅ Model '${this.modelName}' loaded and ready.`);
this._ready = true;
}
/**
* Returns if the model is loaded and ready for inference
* @returns {boolean}
*/
isReady() {
return this._ready;
}
/**
* Perform inference on the on-device model.
*
* @param prompt - The input prompt string
* @param maxNewTokens - Maximum number of new tokens to generate
* @returns {Promise<{answer: string, stats: {input_tokens: (number|undefined), output_tokens: (number|undefined)}}>}
*/
async infer(prompt, {maxNewTokens = 50} = {}) {
if (!this._ready || !this._model) {
console.log("model not ready:", this._ready, this._model);
throw new Error('Model not loaded. Call load() first.');
}
console.log("π Running inference on-device for prompt:\n", prompt);
const messages = [
{ role: "user", content: prompt },
];
const output = await this._model(messages, {
max_new_tokens: maxNewTokens,
temperature: 0.2,
});
console.log("β
Completed inference on-device for prompt:\n", prompt);
// Take the last message of the generated conversation, which is the model's answer.
const generated_output = output[0]?.generated_text;
const text = generated_output?.[generated_output.length - 1]?.content?.trim() || '';
// todo calculate input and output tokens
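// One possible approach (illustrative sketch, assuming the pipeline exposes its
// tokenizer as `this._model.tokenizer`, as transformers.js pipelines generally do):
//   const inputTokens = this._model.tokenizer.encode(prompt).length;
//   const outputTokens = this._model.tokenizer.encode(text).length;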
return {answer: text, stats: {input_tokens: undefined, output_tokens: undefined}};
}
/**
* Update the configuration with new values.
* Changes take effect the next time load() is called.
*
* @param modelName - The name of the model to use
* @param quantization - The model quantization / dtype to use (e.g. 'fp32', 'q4')
*/
updateConfig({modelName, quantization} = {}) {
if (modelName) this.modelName = modelName;
if (quantization) this.modelQuantization = quantization;
}
/**
* Retrieve the name of the currently configured model.
*
* @returns {string} - The name of the model as a string.
*/
getModelName() {
return this.modelName;
}
}
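
// Example usage (illustrative sketch; the model name below is an assumption, any
// transformers.js-compatible text-generation model should work):
//
//   const service = new OnDeviceService({
//     modelName: 'HuggingFaceTB/SmolLM2-135M-Instruct',
//     quantization: 'q4',
//   });
//   await service.load();   // downloads (or reads from cache) and initializes the model
//   if (service.isReady()) {
//     const {answer} = await service.infer('What is WebGPU?', {maxNewTokens: 64});
//     console.log(answer);
//   }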