Spaces:
Running
Running
| import React, { useState, useEffect, useRef } from 'react'; | |
| import * as d3 from 'd3'; | |
| import { useTheme } from '../context/themeContext'; | |
| import MODELS from '../utils/models'; | |
| import DEVICES from '../utils/devices'; | |
/** Weight precisions the calculator compares side by side. */
type Precision = '32-bit' | '16-bit' | '8-bit' | '4-bit';

/** Memory left over for inputs at each precision, in GB. */
interface AvailableMemory {
  '4-bit': number;
  '8-bit': number;
  '16-bit': number;
  '32-bit': number;
}

/** One point on a batch-size / sequence-length curve. */
interface LineChartData {
  seqLength: number;
  batchSize: number;
}

/** Props accepted by the model footprint bar chart. */
interface ModelSizeBarChartProps {
  /** Footprint drawn as the filled bar, in GB. */
  modelSize: number;
  /**
   * Upper bound of the bar's scale: a footprint above this value is drawn as
   * "Out of Memory". Originally described as the largest model in full
   * precision (32-bit).
   */
  largestModelSize: number;
  /** Precision whose colour is used for the bar. */
  modelPrecision: Precision;
  /** When true, the unused remainder of the scale is outlined. */
  deviceMemorySet: boolean;
  /** Optional extra segment stacked after the model bar; the chart defaults it to 0. */
  activationMemorySize?: number;
}

/** Props accepted by the batch-size / sequence-length line chart. */
interface InferenceRuntimeLineChartProps {
  /** Free memory per precision, in GB. */
  availableMemory: AvailableMemory;
  /** Memory consumed per input unit, in GB. */
  memoryPerInput: number;
}
/**
 * Looks up the hex colour associated with a precision label.
 *
 * @param precision - Precision label to colour.
 * @returns The palette entry for the label, or `'gray'` when the lookup
 *   yields nothing.
 */
function chooseColor(precision: Precision) {
  const palette = {
    '4-bit': '#383d95',
    '8-bit': '#71cce9',
    '16-bit': '#ffc068',
    '32-bit': '#e45f5b',
  };
  const match = palette[precision];
  if (match) {
    return match;
  }
  return 'gray';
}
/**
 * Footprint of the model weights alone at a given precision.
 *
 * The parameter count is multiplied by the per-parameter factor for the
 * precision (4, 2, 1 or 0.5); the result is logged and returned.
 *
 * @param modelParams - Parameter count; the UI collects this in billions, so
 *   the product reads as GB.
 * @param precision - Precision the weights are stored in.
 * @returns The weight footprint (GB for a parameter count in billions).
 */
function calculateStandardMemory(modelParams: number, precision: Precision): number {
  const bytesPerParam = {
    '4-bit': 0.5,
    '8-bit': 1,
    '16-bit': 2,
    '32-bit': 4,
  }[precision];
  const footprint = modelParams * bytesPerParam; // GB
  console.log(`[Standard] ${precision.toUpperCase()} Memory:`, footprint);
  return footprint;
}
/**
 * Footprint with prefill chunking: model weights plus the activation buffer
 * for one chunk plus the memory for one input unit.
 *
 * The three components and the total are logged before the total is returned.
 *
 * @param modelParams - Parameter count; the UI collects this in billions, so
 *   the weight term reads as GB.
 * @param hiddenSize - Hidden size of the model.
 * @param numLayers - Number of layers.
 * @param intermediateSize - Intermediate size of the model.
 * @param precision - Precision the weights are stored in.
 * @param maxChunkSize - Prefill chunk size used for the activation term.
 *   Defaults to 512, the value that used to be hard-coded.
 * @returns Total footprint in GB.
 */
function calculatePrefillMemory(
  modelParams: number,
  hiddenSize: number,
  numLayers: number,
  intermediateSize: number,
  precision: Precision,
  maxChunkSize: number = 512
): number {
  const precisionFactor = {
    '32-bit': 4,
    '16-bit': 2,
    '8-bit': 1,
    '4-bit': 0.5,
  };

  // Weights only: parameter count times the per-parameter factor.
  const modelMemorySize = modelParams * precisionFactor[precision]; // GB

  // Activation buffer for one chunk, sized by the wider of the two
  // candidates (2 * intermediate vs 4 * hidden) and converted to GB.
  const activationMemorySize =
    (maxChunkSize * 2 * Math.max(2 * intermediateSize, 4 * hiddenSize)) / 1_000_000_000; // GB

  // Memory for a single input unit across all layers, converted to GB.
  const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // GB

  const totalMemory = modelMemorySize + activationMemorySize + memoryPerInput;

  console.log(`[Prefill] ${precision.toUpperCase()} Memory:`, totalMemory);
  console.log(`[Prefill] Activation Memory:`, activationMemorySize);
  console.log(`[Prefill] Memory Per Input:`, memoryPerInput);

  return totalMemory;
}
/**
 * Single horizontal bar showing a footprint against an upper bound.
 *
 * Shared by the standard and prefill-chunking calculators. The bar is drawn
 * with D3 into a 600x50 SVG whenever any prop or the theme changes:
 * - if `modelSize + activationMemorySize` exceeds `largestModelSize`, a dashed
 *   outline with the text "Out of Memory" is drawn instead of a bar;
 * - otherwise the model segment is drawn in the precision's colour, followed
 *   by an activation segment (when positive) and, when `deviceMemorySet` is
 *   true, an outlined segment for the unused remainder.
 *
 * Nothing is drawn or cleared while `modelSize` or `largestModelSize` is not
 * positive.
 */
function ModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
  activationMemorySize = 0,
}: ModelSizeBarChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const width = 600;
  const height = 50;

  useEffect(() => {
    // Guard: leave the SVG exactly as it is until both sizes are positive.
    if (!(modelSize > 0 && largestModelSize > 0)) {
      return;
    }

    const svg = d3.select(chartRef.current);
    svg.selectAll('*').remove();
    svg
      .attr('width', width)
      .attr('height', height)
      .style('animation', 'fadeIn 0.3s ease-in-out') // Inline animation
      .style('transition', 'transform 0.3s ease-in-out') // Hover effect
      .on('mouseover', function () {
        d3.select(this).style('transform', 'scale(1.02)');
      })
      .on('mouseout', function () {
        d3.select(this).style('transform', 'scale(1)');
      });

    // Maps a size in GB onto the horizontal pixel range of the bar.
    const toPixels = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]);
    const usedMemory = modelSize + activationMemorySize;

    if (usedMemory > largestModelSize) {
      // Does not fit: dashed placeholder plus a centred label.
      const contrastColor = theme === 'dark' ? '#f9fafb' : '#181f26';

      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', contrastColor)
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px');

      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        .attr('alignment-baseline', 'middle')
        .attr('fill', contrastColor)
        .text('Out of Memory');

      return;
    }

    const precisionColor = chooseColor(modelPrecision);

    // Model segment.
    svg
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', toPixels(modelSize))
      .attr('height', height)
      .attr('fill', precisionColor);

    // Activation segment, stacked directly after the model segment.
    if (activationMemorySize > 0) {
      svg
        .append('rect')
        .attr('x', toPixels(modelSize))
        .attr('y', 0)
        .attr('width', toPixels(activationMemorySize))
        .attr('height', height)
        .attr('fill', '#a4b8e0');
    }

    // Outline of whatever is left of the scale.
    if (deviceMemorySet) {
      svg
        .append('rect')
        .attr('x', toPixels(usedMemory))
        .attr('y', 0)
        .attr('width', toPixels(largestModelSize - usedMemory))
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', precisionColor)
        .style('stroke-width', '2px');
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]);

  return <svg ref={chartRef}></svg>;
}
/**
 * Line chart of the largest batch size that fits for each sequence length,
 * one curve per precision. Shared by the standard and prefill-chunking
 * calculators.
 *
 * For every sequence length from 1 up to (but excluding) 4096 the batch size
 * is `memory / (seqLength * memoryPerInput)`; points are kept only when the
 * batch size is in (1, 128] and the sequence length is above 1.
 *
 * The previous drawing is always cleared first, so the chart goes blank when
 * `memoryPerInput` is not positive or no precision has memory left, rather
 * than showing curves from an earlier input. The effect depends on the four
 * memory numbers instead of the `availableMemory` object, so a caller that
 * rebuilds the object with unchanged values does not trigger a redraw.
 */
function InferenceRuntimeLineChart({ availableMemory, memoryPerInput }: InferenceRuntimeLineChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const tooltipRef = useRef<HTMLDivElement>(null); // Ref for the tooltip
  const maxSeqLength = 4096;
  const maxBatchSize = 128;

  // Pulled out as primitives so they can be used as effect dependencies.
  const memory4Bit = availableMemory['4-bit'];
  const memory8Bit = availableMemory['8-bit'];
  const memory16Bit = availableMemory['16-bit'];
  const memory32Bit = availableMemory['32-bit'];

  useEffect(() => {
    const svg = d3.select(chartRef.current);
    const tooltip = d3.select(tooltipRef.current);

    // Always start from a blank chart and a hidden tooltip.
    svg.selectAll('*').remove();
    tooltip.style('opacity', 0);

    // Drawing order of the curves (later entries are painted on top).
    const memoryByPrecision: [Precision, number][] = [
      ['4-bit', memory4Bit],
      ['8-bit', memory8Bit],
      ['16-bit', memory16Bit],
      ['32-bit', memory32Bit],
    ];

    const anyMemoryLeft = memoryByPrecision.some(([, memory]) => memory > 0);
    if (!(memoryPerInput > 0) || !anyMemoryLeft) {
      return;
    }

    const textColor = theme === 'dark' ? '#f9fafb' : '#181f26';
    const margin = { top: 20, right: 20, bottom: 50, left: 50 };
    const width = 600 - margin.left - margin.right;
    const height = 400 - margin.top - margin.bottom;

    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]);
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]);
    const xAxis = d3.axisBottom(xScale);
    const yAxis = d3.axisLeft(yScale);

    const zoom = d3.zoom<SVGGElement, unknown>()
      .scaleExtent([0.5, 10])
      .translateExtent([[-width, -height], [2 * width, 2 * height]])
      .on('zoom', (event) => {
        const transform = event.transform;
        svg.select<SVGGElement>('.x-axis').call(xAxis.scale(transform.rescaleX(xScale)));
        svg.select<SVGGElement>('.y-axis').call(yAxis.scale(transform.rescaleY(yScale)));
        svg.selectAll('path').attr('transform', transform.toString());
      });

    // NOTE(review): the zoom behaviour is attached to a group that never
    // receives children, so it likely has no hit area; confirm whether zoom
    // is actually reachable before relying on it.
    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom)
      .append('g')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
      .call(zoom);

    svg.append('g')
      .attr('class', 'x-axis')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(xAxis);
    svg.append('g')
      .attr('class', 'y-axis')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
      .call(yAxis);

    // Axis titles.
    svg.append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', textColor)
      .text('Sequence Length');
    svg.append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', textColor)
      .text('Batch Size');

    // Legend: colours come from chooseColor so they always match the curves.
    const legendPrecisions: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];
    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`);
    legendPrecisions.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`);
      legendItem.append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', chooseColor(precision));
      legendItem.append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision)
        .style('font-size', '16px')
        .style('fill', textColor)
        .attr('alignment-baseline', 'middle');
    });
    legend.append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', legendPrecisions.length * 30)
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', textColor);

    tooltip
      .style('position', 'absolute')
      .style('padding', '8px')
      .style('border-radius', '4px')
      .style('pointer-events', 'none')
      .style('opacity', 0)
      .style('transition', 'opacity 0.3s ease-in-out, transform 0.3s ease-in-out')
      .style('background-color', 'rgba(0, 0, 0, 0.75)')
      .style('color', 'white')
      .style('font-size', '14px');

    const line = d3.line<LineChartData>()
      .x((d) => xScale(d.seqLength))
      .y((d) => yScale(d.batchSize))
      .curve(d3.curveBasis);

    for (const [precision, memory] of memoryByPrecision) {
      // Keep only the part of the curve that lies inside the plotted range.
      const curvePoints: LineChartData[] = d3.range(1, maxSeqLength, 1)
        .map((seqLength) => ({
          seqLength,
          batchSize: memory / (seqLength * memoryPerInput),
        }))
        .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1 && d.seqLength > 1);

      const lineGroup = svg.append('g').attr('transform', `translate(${margin.left}, ${margin.top})`);

      lineGroup.append('path')
        .datum(curvePoints)
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision))
        .attr('stroke-width', 4)
        .attr('d', line)
        .on('mouseover', () => {
          tooltip.style('opacity', 1)
            .style('transform', 'translateY(-10px)');
        })
        .on('mousemove', (event) => {
          // Pointer position is converted back into data space for the read-out.
          const [x, y] = d3.pointer(event);
          const xValue = xScale.invert(x);
          const yValue = yScale.invert(y);
          tooltip.html(`Sequence Length: ${xValue.toFixed(0)}<br/>Batch Size: ${yValue.toFixed(0)}`)
            .style('left', event.pageX + 10 + 'px')
            .style('top', event.pageY + 10 + 'px');
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0);
        });
    }
  }, [memory4Bit, memory8Bit, memory16Bit, memory32Bit, memoryPerInput, theme]);

  return (
    <>
      <div id="tooltip" ref={tooltipRef}></div>
      <svg ref={chartRef} width={600} height={400} />
    </>
  );
}
/**
 * Results panel for the prefill-chunking mode: one footprint bar per
 * precision plus the batch-size / sequence-length chart.
 *
 * Renders nothing until every input is a non-zero number.
 *
 * The total returned by `calculatePrefillMemory` already contains the
 * activation buffer, so the bar chart receives the total *minus* the
 * activation part as `modelSize` and the activation part separately. That way
 * the two stacked segments add up to the total shown in the label and
 * activation memory is not counted twice in the out-of-memory check.
 */
function PrefillChunkingCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
  intermediateSize,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
  intermediateSize: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) {
    return null;
  }

  // Activation buffer for one 512-token chunk, in GB (same expression as the
  // activation term inside calculatePrefillMemory).
  const activationMemorySize = (512 * 2 * (Math.max(2 * intermediateSize, 4 * hiddenSize))) / 1_000_000_000;

  // Compute each precision's total once and reuse it for both charts.
  const totalMemoryByPrecision = {
    '32-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '32-bit'),
    '16-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '16-bit'),
    '8-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '8-bit'),
    '4-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '4-bit'),
  };
  const precisions: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];

  return (
    <>
      {/* Model Footprint with Prefill Chunking */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">Model Footprint with Prefill Chunking</div>
        <div className="space-y-8">
          {precisions.map((precision) => {
            const totalMemory = totalMemoryByPrecision[precision];
            return (
              <div key={precision} style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="chart-row">
                <div className="chart-row-title">{precision.toUpperCase()}</div>
                <ModelSizeBarChart
                  modelSize={totalMemory - activationMemorySize}
                  largestModelSize={deviceMemory}
                  modelPrecision={precision}
                  deviceMemorySet={deviceMemory > 0}
                  activationMemorySize={activationMemorySize}
                />
                <div className="chart-row-size ml-8">
                  {totalMemory.toFixed(2)} / {deviceMemory} GB
                </div>
              </div>
            );
          })}
        </div>
      </div>
      {/* Inference Runtime with Prefill Chunking */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">
          Maximum Batch Size / Sequence Length with Prefill Chunking
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            '4-bit': deviceMemory - totalMemoryByPrecision['4-bit'],
            '8-bit': deviceMemory - totalMemoryByPrecision['8-bit'],
            '16-bit': deviceMemory - totalMemoryByPrecision['16-bit'],
            '32-bit': deviceMemory - totalMemoryByPrecision['32-bit'],
          }}
          memoryPerInput={(4 * hiddenSize * numLayers) / 1_000_000_000}
        />
      </div>
    </>
  );
}
/**
 * Results panel for the standard mode: one footprint bar per precision plus
 * the batch-size / sequence-length chart.
 *
 * Renders nothing until every input is a non-zero number. Each precision's
 * footprint is computed once per render and reused for the bar, its label and
 * the available-memory figures, instead of calling `calculateStandardMemory`
 * (which also logs) three times per precision.
 */
function StandardCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) {
    return null;
  }

  const modelMemoryByPrecision = {
    '32-bit': calculateStandardMemory(modelParams, '32-bit'),
    '16-bit': calculateStandardMemory(modelParams, '16-bit'),
    '8-bit': calculateStandardMemory(modelParams, '8-bit'),
    '4-bit': calculateStandardMemory(modelParams, '4-bit'),
  };
  const precisions: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];

  return (
    <>
      {/* Model Footprint */}
      <div className="chart mb-8">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">Model Footprint</div>
        <div className="space-y-8">
          {precisions.map((precision) => {
            const modelMemory = modelMemoryByPrecision[precision];
            return (
              <div key={precision} style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="chart-row">
                <div className="chart-row-title">{precision.toUpperCase()}</div>
                <ModelSizeBarChart
                  modelSize={modelMemory}
                  largestModelSize={deviceMemory}
                  modelPrecision={precision}
                  deviceMemorySet={deviceMemory > 0}
                />
                <div className="chart-row-size ml-8">
                  {modelMemory.toFixed(2)} / {deviceMemory} GB
                </div>
              </div>
            );
          })}
        </div>
      </div>
      {/* Maximum Batch Size / Sequence Length */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">
          Maximum Batch Size / Sequence Length
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            '4-bit': deviceMemory - modelMemoryByPrecision['4-bit'],
            '8-bit': deviceMemory - modelMemoryByPrecision['8-bit'],
            '16-bit': deviceMemory - modelMemoryByPrecision['16-bit'],
            '32-bit': deviceMemory - modelMemoryByPrecision['32-bit'],
          }}
          memoryPerInput={(4 * hiddenSize * numLayers) / 1_000_000_000}
        />
      </div>
    </>
  );
}
/**
 * Main calculator page.
 *
 * Holds the model inputs (parameter count, hidden size, layer count,
 * intermediate size), the device memory and three UI toggles (standard vs
 * prefill chunking, preset vs custom model, preset vs custom device), and
 * renders the matching results panel below the inputs.
 *
 * Preset models are identified by name in the dropdown, so two presets with
 * the same parameter count resolve to the right entry, and picking
 * "None selected" clears the model values instead of leaving the previous
 * model's numbers in place.
 */
const Calculator = () => {
  const [modelParams, setModelParams] = useState<number | null>(null);
  const [hiddenSize, setHiddenSize] = useState<number | null>(null);
  const [numLayers, setNumLayers] = useState<number | null>(null);
  const [intermediateSize, setIntermediateSize] = useState<number | null>(null);
  const [deviceMemory, setDeviceMemory] = useState<number | null>(null);
  const [isPrefillChunking, setIsPrefillChunking] = useState<boolean>(false);
  const [modelSelectionTab, setModelSelectionTab] = useState<boolean>(true);
  const [deviceSelectionTab, setDeviceSelectionTab] = useState<boolean>(true);

  return (
    <div className="flex flex-col items-center justify-center min-h-screen px-4">
      {/* Toggle Between Standard and Prefill Chunking */}
      <div className="mb-4 flex space-x-4">
        <button
          style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
          className={`${!isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
          onClick={() => setIsPrefillChunking(false)}
        >
          Standard Calculator
        </button>
        <button
          style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
          className={`${isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
          onClick={() => setIsPrefillChunking(true)}
        >
          Calculator with Prefill Chunking
        </button>
      </div>
      {/* Model and Device Selection */}
      <div className="w-full max-w-4xl">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-4xl mb-4 text-center">Model Memory Calculator</div>
        <div className="mb-6 text-center">
          Use our Model Memory Calculator to help you estimate the memory footprint of your model
          and the maximum batch size/sequence length combination you can run on your device.
        </div>
        <div className="grid grid-cols-1 sm:grid-cols-2 gap-4 mb-6">
          {/* Model Selection */}
          <div className="calculator-input-box">
            <div className="text-2xl calculator-input-title">Model</div>
            <div className="calculator-input-content">
              <div className="mb-2">
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${modelSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
                  onClick={() => setModelSelectionTab(true)}
                >
                  Model Selection
                </button>
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${modelSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}`}
                  onClick={() => setModelSelectionTab(false)}
                >
                  Custom Model
                </button>
              </div>
              <div>
                {modelSelectionTab ? (
                  <>
                    <label htmlFor="model">Select a Model</label>
                    <select
                      id="model"
                      className="calculator-select"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      onChange={(e) => {
                        // Options carry the model name, which is also the
                        // React key and therefore expected to be unique.
                        const selectedModel = MODELS.find(
                          (model) => model.name === e.target.value
                        );
                        if (selectedModel) {
                          setModelParams(selectedModel.params);
                          setHiddenSize(selectedModel.hidden_size);
                          setNumLayers(selectedModel.num_hidden_layers);
                          setIntermediateSize(selectedModel.intermediate_size);
                        } else {
                          // "None selected": drop the previous model's values.
                          setModelParams(null);
                          setHiddenSize(null);
                          setNumLayers(null);
                          setIntermediateSize(null);
                        }
                      }}
                    >
                      <option value="">None selected</option>
                      {MODELS.map((model) => (
                        <option key={model.name} value={model.name}>
                          {model.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor="modelParams">Model Parameters (in billions)</label>
                    <input
                      type="number"
                      id="modelParams"
                      className="calculator-input mb-2"
                      placeholder="e.g. 7 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={modelParams || ''}
                      min={0}
                      onChange={(e) => setModelParams(Number(e.target.value))}
                    />
                    <label htmlFor="hiddenSize">Hidden Size</label>
                    <input
                      type="number"
                      id="hiddenSize"
                      className="calculator-input mb-2"
                      placeholder="e.g. 4096 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={hiddenSize || ''}
                      min={1}
                      onChange={(e) => setHiddenSize(Number(e.target.value))}
                    />
                    <label htmlFor="numLayers">Number of Layers</label>
                    <input
                      type="number"
                      id="numLayers"
                      className="calculator-input"
                      placeholder="e.g. 32 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={numLayers || ''}
                      min={1}
                      onChange={(e) => setNumLayers(Number(e.target.value))}
                    />
                    {isPrefillChunking && (
                      <>
                        <label htmlFor="intermediateSize">Intermediate Size</label>
                        <input
                          type="number"
                          id="intermediateSize"
                          className="calculator-input"
                          placeholder="e.g. 11008 (for LLaMA-7B)"
                          style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                          value={intermediateSize || ''}
                          min={1}
                          onChange={(e) => setIntermediateSize(Number(e.target.value))}
                        />
                      </>
                    )}
                  </>
                )}
              </div>
            </div>
          </div>
          {/* Device Selection */}
          <div className="calculator-input-box">
            <div className="text-2xl calculator-input-title">Device</div>
            <div className="calculator-input-content">
              <div className="mb-2">
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${deviceSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
                  onClick={() => {
                    setDeviceSelectionTab(true);
                    setDeviceMemory(null);
                  }}
                >
                  Device Selection
                </button>
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${deviceSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}`}
                  onClick={() => {
                    setDeviceSelectionTab(false);
                    setDeviceMemory(null);
                  }}
                >
                  Custom Device
                </button>
              </div>
              <div>
                {deviceSelectionTab ? (
                  <>
                    <label htmlFor="device">Select a Device</label>
                    <select
                      id="device"
                      className="calculator-select"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      onChange={(e) => setDeviceMemory(Number(e.target.value))}
                    >
                      <option value="">None selected</option>
                      {DEVICES.map((device) => (
                        <option key={device.name} value={device.size}>
                          {device.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor="deviceMemory">Device RAM (in GB)</label>
                    <input
                      type="number"
                      id="deviceMemory"
                      className="calculator-input"
                      placeholder="e.g. 24"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={deviceMemory || ''}
                      min={0}
                      onChange={(e) => setDeviceMemory(Number(e.target.value))}
                    />
                  </>
                )}
              </div>
            </div>
          </div>
        </div>
        {/* Render Appropriate Calculator Based on Toggle.
            Unset values are passed as 0, which both panels treat as
            "not provided" and render nothing for. */}
        {isPrefillChunking ? (
          <PrefillChunkingCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
            intermediateSize={intermediateSize ?? 0}
          />
        ) : (
          <StandardCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
          />
        )}
      </div>
    </div>
  );
};

export default Calculator;