File size: 3,878 Bytes
f8c42c9
 
8b3b676
bfaf968
 
 
 
 
 
 
1d22bd8
bfaf968
1d22bd8
bfaf968
 
 
 
 
 
 
 
 
 
 
 
 
7a38188
b1ed689
 
 
 
 
 
 
 
 
 
 
 
 
 
7f224da
f8c42c9
7a38188
d720b4b
383a122
f8c42c9
7a38188
f8c42c9
bfaf968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72519d5
f8c42c9
8b3b676
f8c42c9
 
7a38188
6ac6471
8b3b676
 
 
 
 
6ac6471
8b3b676
f8c42c9
b1ed689
7a38188
 
8b3b676
 
 
9674cf0
 
 
bfaf968
 
 
 
 
 
 
383a122
bfaf968
383a122
bfaf968
b9991b1
 
 
 
 
 
 
 
 
 
bfaf968
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// OnDeviceService: uses Xenova's transformers.js to run a small causal LM in browser
// Uses ES module import for Xenova's transformers.js
import {pipeline} from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';


/**
 * On-device LLM inference service using transformers.js.
 *
 * Lifecycle: construct with a model id and quantization, call `load()` to create a
 * `text-generation` pipeline, then call `infer()`. Changing the model or quantization via
 * `updateConfig()` drops the loaded pipeline, so `load()` must be called again before the
 * next `infer()`.
 */
export class OnDeviceService {
    /**
     * @param {object} [options]
     * @param {string} [options.modelName=''] - Model id passed to transformers.js `pipeline()`.
     * @param {string} [options.quantization='fp32'] - Forwarded to `pipeline()` as its `dtype` option.
     */
    constructor({modelName = '', quantization = 'fp32'} = {}) {
        this.modelName = modelName;
        this.modelQuantization = quantization;
        this._ready = false; // true only after load() has completed successfully
        this._model = null;  // the text-generation pipeline returned by pipeline()
    }


    /**
     * Load the model into memory so it is ready for inference.
     * Downloading (and any caching) of model files is handled by transformers.js `pipeline()`.
     *
     * @param {(progress: object) => void} [progressCb] - Receives loading progress events from
     *     transformers.js. When omitted, a default callback logs `status` and `loaded`/`total`.
     * @returns {Promise<void>}
     * @throws Rethrows any error from `pipeline()`; the service is then left not ready.
     */
    async load(progressCb) {
        console.log(`⬇️ Download Model '${this.modelName}'...`);

        // Invalidate any previous model up front: isReady() must be false while (re)loading,
        // and must stay false if pipeline() rejects.
        this._ready = false;
        this._model = null;

        // Fallback progress logger used when the caller supplies none.
        const defaultProgressCb = (progress) => {
            if (progress && typeof progress === 'object') {
                if (progress.status) {
                    console.log(`[Model Loading] ${progress.status}`);
                }
                // Requiring a truthy `total` also avoids dividing by zero.
                if (progress.loaded && progress.total) {
                    const percent = ((progress.loaded / progress.total) * 100).toFixed(1);
                    console.log(`[Model Loading] ${percent}% (${progress.loaded}/${progress.total} bytes)`);
                }
            } else {
                console.log(`[Model Loading] Progress:`, progress);
            }
        };

        this._model = await pipeline('text-generation', this.modelName, {
            progress_callback: progressCb || defaultProgressCb,
            // NOTE(review): this requests the WebGPU backend; confirm how transformers.js behaves
            // in browsers without WebGPU (an earlier comment assumed an automatic fallback).
            device: 'webgpu',
            dtype: this.modelQuantization, // model quantization ('fp32' unless configured)
        });
        console.log(`✅ Model '${this.modelName}' loaded and ready.`);
        this._ready = true;
    }


    /**
     * Returns whether a model is loaded and ready for inference.
     * @returns {boolean}
     */
    isReady() {
        return this._ready;
    }


    /**
     * Run a single-turn chat completion on the on-device model.
     *
     * @param {string} prompt - The user message sent to the model.
     * @param {object} [options]
     * @param {number} [options.maxNewTokens=50] - Maximum number of new tokens to generate.
     * @returns {Promise<{answer: string, stats: {input_tokens: (number|undefined), output_tokens: (number|undefined)}}>}
     *     The trimmed model reply (empty string if none could be extracted) and token stats
     *     (currently always undefined).
     * @throws {Error} If `load()` has not completed successfully.
     */
    async infer(prompt, {maxNewTokens = 50} = {}) {
        if (!this._ready || !this._model) {
            console.log("model not ready:", this._ready, this._model);
            throw new Error('Model not loaded. Call load() first.');
        }
        console.log("🔄 Running inference on-device for prompt:\n", prompt);

        const messages = [
            {role: "user", content: prompt},
        ];

        // NOTE(review): transformers.js decodes greedily unless `do_sample: true` is set, so
        // `temperature` likely has no effect here — confirm whether sampling is intended.
        const output = await this._model(messages, {
            max_new_tokens: maxNewTokens,
            temperature: 0.2,
        });

        console.log("✅ Completed inference on-device for prompt:\n", prompt);

        // With chat-style input, `generated_text` is the whole conversation and its last
        // message is the model's reply. Tolerate a plain-string result or an empty output
        // instead of throwing a TypeError.
        const generated = output?.[0]?.generated_text;
        const rawText = Array.isArray(generated) ? generated.at(-1)?.content : generated;
        const text = typeof rawText === 'string' ? rawText.trim() : '';

        // TODO: calculate input and output token counts
        return {answer: text, stats: {input_tokens: undefined, output_tokens: undefined}};
    }

    /**
     * Update configuration with new values. Falsy arguments are ignored.
     * If the model name or quantization actually changes, the loaded pipeline no longer
     * matches the configuration, so it is dropped and `load()` must be called again.
     *
     * @param {object} [options]
     * @param {string} [options.modelName] - The name of the model to use.
     * @param {string} [options.quantization] - The quantization (`dtype`) to load the model with.
     */
    updateConfig({modelName, quantization} = {}) {
        let changed = false;
        if (modelName && modelName !== this.modelName) {
            this.modelName = modelName;
            changed = true;
        }
        if (quantization && quantization !== this.modelQuantization) {
            this.modelQuantization = quantization;
            changed = true;
        }
        if (changed) {
            this._ready = false;
            this._model = null;
        }
    }


    /**
     * Retrieve the configured model name (the one `load()` uses).
     *
     * @returns {string} - The name of the model as a string.
     */
    getModelName() {
        return this.modelName;
    }
}