Upload 2 files
Browse files
- app.py +57 -69
- inference.py +18 -20
app.py
CHANGED
|
@@ -15,7 +15,7 @@ from inference import inference_patch
|
|
| 15 |
from convert import abc2xml, xml2, pdf2img
|
| 16 |
|
| 17 |
|
| 18 |
-
#
|
| 19 |
with open('prompts.txt', 'r') as f:
|
| 20 |
prompts = f.readlines()
|
| 21 |
|
|
@@ -25,12 +25,12 @@ for prompt in prompts:
|
|
| 25 |
parts = prompt.split('_')
|
| 26 |
valid_combinations.add((parts[0], parts[1], parts[2]))
|
| 27 |
|
| 28 |
-
#
|
| 29 |
periods = sorted({p for p, _, _ in valid_combinations})
|
| 30 |
composers = sorted({c for _, c, _ in valid_combinations})
|
| 31 |
instruments = sorted({i for _, _, i in valid_combinations})
|
| 32 |
|
| 33 |
-
#
|
| 34 |
def update_components(period, composer):
|
| 35 |
if not period:
|
| 36 |
return [
|
|
@@ -54,7 +54,7 @@ def update_components(period, composer):
|
|
| 54 |
)
|
| 55 |
]
|
| 56 |
|
| 57 |
-
#
|
| 58 |
class RealtimeStream(TextIOBase):
|
| 59 |
def __init__(self, queue):
|
| 60 |
self.queue = queue
|
|
@@ -81,7 +81,7 @@ def convert_files(abc_content, period, composer, instrumentation):
|
|
| 81 |
with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
|
| 82 |
f.write(postprocessed_inst_abc)
|
| 83 |
|
| 84 |
-
#
|
| 85 |
file_paths = {'abc': abc_filename}
|
| 86 |
try:
|
| 87 |
# abc2xml
|
|
@@ -115,15 +115,15 @@ def convert_files(abc_content, period, composer, instrumentation):
|
|
| 115 |
})
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
-
raise gr.Error(f"
|
| 119 |
|
| 120 |
return file_paths
|
| 121 |
|
| 122 |
|
| 123 |
-
#
|
| 124 |
def update_page(direction, data):
|
| 125 |
"""
|
| 126 |
-
data
|
| 127 |
"""
|
| 128 |
if not data:
|
| 129 |
return None, gr.update(interactive=False), gr.update(interactive=False), data
|
|
@@ -134,9 +134,9 @@ def update_page(direction, data):
|
|
| 134 |
data['current_page'] += 1
|
| 135 |
|
| 136 |
current_page_index = data['current_page']
|
| 137 |
-
#
|
| 138 |
new_image = f"{data['base']}_page_{current_page_index+1}.png"
|
| 139 |
-
#
|
| 140 |
prev_btn_state = gr.update(interactive=(current_page_index > 0))
|
| 141 |
next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
|
| 142 |
|
|
@@ -146,13 +146,13 @@ def update_page(direction, data):
|
|
| 146 |
@spaces.GPU(duration=600)
|
| 147 |
def generate_music(period, composer, instrumentation):
|
| 148 |
"""
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
1) process_output (
|
| 152 |
-
2) final_output (
|
| 153 |
-
3) pdf_image (
|
| 154 |
-
4) audio_player (mp3
|
| 155 |
-
5) pdf_state (
|
| 156 |
"""
|
| 157 |
# Set a different random seed each time based on current timestamp
|
| 158 |
random_seed = int(time.time()) % 10000
|
|
@@ -175,7 +175,7 @@ def generate_music(period, composer, instrumentation):
|
|
| 175 |
pass
|
| 176 |
|
| 177 |
if (period, composer, instrumentation) not in valid_combinations:
|
| 178 |
-
#
|
| 179 |
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
|
| 180 |
|
| 181 |
output_queue = queue.Queue()
|
|
@@ -186,7 +186,7 @@ def generate_music(period, composer, instrumentation):
|
|
| 186 |
|
| 187 |
def run_inference():
|
| 188 |
try:
|
| 189 |
-
#
|
| 190 |
result = inference_patch(period, composer, instrumentation)
|
| 191 |
result_container.append(result)
|
| 192 |
finally:
|
|
@@ -201,40 +201,40 @@ def generate_music(period, composer, instrumentation):
|
|
| 201 |
audio_file = None
|
| 202 |
pdf_state = None
|
| 203 |
|
| 204 |
-
#
|
| 205 |
while thread.is_alive():
|
| 206 |
try:
|
| 207 |
text = output_queue.get(timeout=0.1)
|
| 208 |
process_output += text
|
| 209 |
-
#
|
| 210 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
| 211 |
except queue.Empty:
|
| 212 |
continue
|
| 213 |
|
| 214 |
-
#
|
| 215 |
while not output_queue.empty():
|
| 216 |
text = output_queue.get()
|
| 217 |
process_output += text
|
| 218 |
|
| 219 |
-
#
|
| 220 |
final_result = result_container[0] if result_container else ""
|
| 221 |
|
| 222 |
-
#
|
| 223 |
final_output_abc = "Converting files..."
|
| 224 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
| 225 |
|
| 226 |
|
| 227 |
-
#
|
| 228 |
try:
|
| 229 |
file_paths = convert_files(final_result, period, composer, instrumentation)
|
| 230 |
final_output_abc = final_result
|
| 231 |
-
#
|
| 232 |
if file_paths['pages'] > 0:
|
| 233 |
pdf_image = f"{file_paths['base']}_page_1.png"
|
| 234 |
audio_file = file_paths['mp3']
|
| 235 |
-
pdf_state = file_paths #
|
| 236 |
|
| 237 |
-
#
|
| 238 |
download_list = []
|
| 239 |
if 'abc' in file_paths and os.path.exists(file_paths['abc']):
|
| 240 |
download_list.append(file_paths['abc'])
|
|
@@ -247,60 +247,60 @@ def generate_music(period, composer, instrumentation):
|
|
| 247 |
if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
|
| 248 |
download_list.append(file_paths['mp3'])
|
| 249 |
except Exception as e:
|
| 250 |
-
#
|
| 251 |
yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
|
| 252 |
return
|
| 253 |
|
| 254 |
-
|
| 255 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
|
| 256 |
|
| 257 |
|
| 258 |
def get_file(file_type, period, composer, instrumentation):
|
| 259 |
"""
|
| 260 |
-
|
| 261 |
"""
|
| 262 |
-
#
|
| 263 |
-
#
|
| 264 |
-
#
|
| 265 |
possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
|
| 266 |
if not possible_files:
|
| 267 |
return None
|
| 268 |
-
#
|
| 269 |
possible_files.sort(key=os.path.getmtime)
|
| 270 |
return possible_files[-1]
|
| 271 |
|
| 272 |
|
| 273 |
css = """
|
| 274 |
-
/*
|
| 275 |
button[size="sm"] {
|
| 276 |
padding: 4px 8px !important;
|
| 277 |
margin: 2px !important;
|
| 278 |
min-width: 60px;
|
| 279 |
}
|
| 280 |
|
| 281 |
-
/* PDF
|
| 282 |
#pdf-preview {
|
| 283 |
-
border-radius: 8px; /*
|
| 284 |
-
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /*
|
| 285 |
}
|
| 286 |
|
| 287 |
.page-btn {
|
| 288 |
-
padding: 12px !important; /*
|
| 289 |
-
margin: auto !important; /*
|
| 290 |
}
|
| 291 |
|
| 292 |
-
/*
|
| 293 |
.page-btn:hover {
|
| 294 |
background: #f0f0f0 !important;
|
| 295 |
transform: scale(1.05);
|
| 296 |
}
|
| 297 |
|
| 298 |
-
/*
|
| 299 |
.gr-row {
|
| 300 |
-
gap: 10px !important; /*
|
| 301 |
}
|
| 302 |
|
| 303 |
-
/*
|
| 304 |
.audio-panel {
|
| 305 |
margin-top: 15px !important;
|
| 306 |
max-width: 400px;
|
|
@@ -310,23 +310,13 @@ button[size="sm"] {
|
|
| 310 |
height: 200px !important;
|
| 311 |
}
|
| 312 |
|
| 313 |
-
/*
|
| 314 |
.save-as-row {
|
| 315 |
margin-top: 15px;
|
| 316 |
padding: 10px;
|
| 317 |
border-top: 1px solid #eee;
|
| 318 |
}
|
| 319 |
|
| 320 |
-
.save-as-label {
|
| 321 |
-
font-weight: bold;
|
| 322 |
-
margin-right: 10px;
|
| 323 |
-
align-self: center;
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
.save-buttons {
|
| 327 |
-
gap: 5px; /* Button spacing */
|
| 328 |
-
}
|
| 329 |
-
|
| 330 |
/* Download files styling */
|
| 331 |
.download-files {
|
| 332 |
margin-top: 15px;
|
|
@@ -339,12 +329,12 @@ button[size="sm"] {
|
|
| 339 |
with gr.Blocks(css=css) as demo:
|
| 340 |
gr.Markdown("## NotaGen")
|
| 341 |
|
| 342 |
-
#
|
| 343 |
pdf_state = gr.State()
|
| 344 |
|
| 345 |
with gr.Column():
|
| 346 |
with gr.Row():
|
| 347 |
-
#
|
| 348 |
with gr.Column():
|
| 349 |
with gr.Row():
|
| 350 |
period_dd = gr.Dropdown(
|
|
@@ -384,18 +374,16 @@ with gr.Blocks(css=css) as demo:
|
|
| 384 |
placeholder="Post-processed ABC scores will be shown here..."
|
| 385 |
)
|
| 386 |
|
| 387 |
-
#
|
| 388 |
audio_player = gr.Audio(
|
| 389 |
label="Audio Preview",
|
| 390 |
format="mp3",
|
| 391 |
interactive=False,
|
| 392 |
-
# container=False,
|
| 393 |
-
# elem_id="audio-preview"
|
| 394 |
)
|
| 395 |
|
| 396 |
-
#
|
| 397 |
with gr.Column():
|
| 398 |
-
#
|
| 399 |
pdf_image = gr.Image(
|
| 400 |
label="Sheet Music Preview",
|
| 401 |
show_label=False,
|
|
@@ -406,7 +394,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 406 |
show_download_button=False
|
| 407 |
)
|
| 408 |
|
| 409 |
-
#
|
| 410 |
with gr.Row():
|
| 411 |
prev_btn = gr.Button(
|
| 412 |
"⬅️ Last Page",
|
|
@@ -430,7 +418,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 430 |
type="filepath" # Make sure this is set to filepath
|
| 431 |
)
|
| 432 |
|
| 433 |
-
#
|
| 434 |
period_dd.change(
|
| 435 |
update_components,
|
| 436 |
inputs=[period_dd, composer_dd],
|
|
@@ -442,26 +430,26 @@ with gr.Blocks(css=css) as demo:
|
|
| 442 |
outputs=[composer_dd, instrument_dd]
|
| 443 |
)
|
| 444 |
|
| 445 |
-
#
|
| 446 |
generate_btn.click(
|
| 447 |
generate_music,
|
| 448 |
inputs=[period_dd, composer_dd, instrument_dd],
|
| 449 |
outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
|
| 450 |
)
|
| 451 |
|
| 452 |
-
#
|
| 453 |
prev_signal = gr.Textbox(value="prev", visible=False)
|
| 454 |
next_signal = gr.Textbox(value="next", visible=False)
|
| 455 |
|
| 456 |
prev_btn.click(
|
| 457 |
update_page,
|
| 458 |
-
inputs=[prev_signal, pdf_state], # ✅
|
| 459 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
| 460 |
)
|
| 461 |
|
| 462 |
next_btn.click(
|
| 463 |
update_page,
|
| 464 |
-
inputs=[next_signal, pdf_state], # ✅
|
| 465 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
| 466 |
)
|
| 467 |
|
|
|
|
| 15 |
from convert import abc2xml, xml2, pdf2img
|
| 16 |
|
| 17 |
|
| 18 |
+
# Read prompt combinations
|
| 19 |
with open('prompts.txt', 'r') as f:
|
| 20 |
prompts = f.readlines()
|
| 21 |
|
|
|
|
| 25 |
parts = prompt.split('_')
|
| 26 |
valid_combinations.add((parts[0], parts[1], parts[2]))
|
| 27 |
|
| 28 |
+
# Prepare dropdown options
|
| 29 |
periods = sorted({p for p, _, _ in valid_combinations})
|
| 30 |
composers = sorted({c for _, c, _ in valid_combinations})
|
| 31 |
instruments = sorted({i for _, _, i in valid_combinations})
|
| 32 |
|
| 33 |
+
# Dynamically update composer and instrument dropdown options
|
| 34 |
def update_components(period, composer):
|
| 35 |
if not period:
|
| 36 |
return [
|
|
|
|
| 54 |
)
|
| 55 |
]
|
| 56 |
|
| 57 |
+
# Custom realtime stream for outputting model inference process to frontend
|
| 58 |
class RealtimeStream(TextIOBase):
|
| 59 |
def __init__(self, queue):
|
| 60 |
self.queue = queue
|
|
|
|
| 81 |
with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
|
| 82 |
f.write(postprocessed_inst_abc)
|
| 83 |
|
| 84 |
+
# Convert files
|
| 85 |
file_paths = {'abc': abc_filename}
|
| 86 |
try:
|
| 87 |
# abc2xml
|
|
|
|
| 115 |
})
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
+
raise gr.Error(f"File processing failed: {str(e)}")
|
| 119 |
|
| 120 |
return file_paths
|
| 121 |
|
| 122 |
|
| 123 |
+
# Page navigation control function
|
| 124 |
def update_page(direction, data):
|
| 125 |
"""
|
| 126 |
+
data contains three key pieces of information: 'pages', 'current_page', and 'base'
|
| 127 |
"""
|
| 128 |
if not data:
|
| 129 |
return None, gr.update(interactive=False), gr.update(interactive=False), data
|
|
|
|
| 134 |
data['current_page'] += 1
|
| 135 |
|
| 136 |
current_page_index = data['current_page']
|
| 137 |
+
# Update image path
|
| 138 |
new_image = f"{data['base']}_page_{current_page_index+1}.png"
|
| 139 |
+
# When current_page==0, prev_btn is disabled; when current_page==pages-1, next_btn is disabled
|
| 140 |
prev_btn_state = gr.update(interactive=(current_page_index > 0))
|
| 141 |
next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
|
| 142 |
|
|
|
|
| 146 |
@spaces.GPU(duration=600)
|
| 147 |
def generate_music(period, composer, instrumentation):
|
| 148 |
"""
|
| 149 |
+
Must ensure each yield returns the same number of values.
|
| 150 |
+
We're preparing to return 5 values, corresponding to:
|
| 151 |
+
1) process_output (intermediate inference information)
|
| 152 |
+
2) final_output (final ABC)
|
| 153 |
+
3) pdf_image (path to the PNG of the first page of the PDF)
|
| 154 |
+
4) audio_player (mp3 path)
|
| 155 |
+
5) pdf_state (state for page navigation)
|
| 156 |
"""
|
| 157 |
# Set a different random seed each time based on current timestamp
|
| 158 |
random_seed = int(time.time()) % 10000
|
|
|
|
| 175 |
pass
|
| 176 |
|
| 177 |
if (period, composer, instrumentation) not in valid_combinations:
|
| 178 |
+
# If the combination is invalid, raise an error
|
| 179 |
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
|
| 180 |
|
| 181 |
output_queue = queue.Queue()
|
|
|
|
| 186 |
|
| 187 |
def run_inference():
|
| 188 |
try:
|
| 189 |
+
# Use downloaded model weights path for inference
|
| 190 |
result = inference_patch(period, composer, instrumentation)
|
| 191 |
result_container.append(result)
|
| 192 |
finally:
|
|
|
|
| 201 |
audio_file = None
|
| 202 |
pdf_state = None
|
| 203 |
|
| 204 |
+
# First continuously read intermediate output
|
| 205 |
while thread.is_alive():
|
| 206 |
try:
|
| 207 |
text = output_queue.get(timeout=0.1)
|
| 208 |
process_output += text
|
| 209 |
+
# No final ABC yet, files not yet converted
|
| 210 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
| 211 |
except queue.Empty:
|
| 212 |
continue
|
| 213 |
|
| 214 |
+
# After thread ends, get all remaining items from the queue
|
| 215 |
while not output_queue.empty():
|
| 216 |
text = output_queue.get()
|
| 217 |
process_output += text
|
| 218 |
|
| 219 |
+
# Final inference result
|
| 220 |
final_result = result_container[0] if result_container else ""
|
| 221 |
|
| 222 |
+
# Display file conversion prompt
|
| 223 |
final_output_abc = "Converting files..."
|
| 224 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
| 225 |
|
| 226 |
|
| 227 |
+
# Convert files
|
| 228 |
try:
|
| 229 |
file_paths = convert_files(final_result, period, composer, instrumentation)
|
| 230 |
final_output_abc = final_result
|
| 231 |
+
# Get the first image and mp3 file
|
| 232 |
if file_paths['pages'] > 0:
|
| 233 |
pdf_image = f"{file_paths['base']}_page_1.png"
|
| 234 |
audio_file = file_paths['mp3']
|
| 235 |
+
pdf_state = file_paths # Directly use the converted information dictionary as state
|
| 236 |
|
| 237 |
+
# Prepare download file list
|
| 238 |
download_list = []
|
| 239 |
if 'abc' in file_paths and os.path.exists(file_paths['abc']):
|
| 240 |
download_list.append(file_paths['abc'])
|
|
|
|
| 247 |
if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
|
| 248 |
download_list.append(file_paths['mp3'])
|
| 249 |
except Exception as e:
|
| 250 |
+
# If conversion fails, return error message to output box
|
| 251 |
yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
|
| 252 |
return
|
| 253 |
|
| 254 |
+
# Final yield with all information - modify here to make component visible
|
| 255 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
|
| 256 |
|
| 257 |
|
| 258 |
def get_file(file_type, period, composer, instrumentation):
|
| 259 |
"""
|
| 260 |
+
Returns the local file of specified type for Gradio download
|
| 261 |
"""
|
| 262 |
+
# Here you actually need to return based on specific file paths saved earlier, simplified for demo
|
| 263 |
+
# If matching by timestamp, you can store all converted files in a directory and get the latest
|
| 264 |
+
# This is just an example:
|
| 265 |
possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
|
| 266 |
if not possible_files:
|
| 267 |
return None
|
| 268 |
+
# Simply return the latest
|
| 269 |
possible_files.sort(key=os.path.getmtime)
|
| 270 |
return possible_files[-1]
|
| 271 |
|
| 272 |
|
| 273 |
css = """
|
| 274 |
+
/* Compact button style */
|
| 275 |
button[size="sm"] {
|
| 276 |
padding: 4px 8px !important;
|
| 277 |
margin: 2px !important;
|
| 278 |
min-width: 60px;
|
| 279 |
}
|
| 280 |
|
| 281 |
+
/* PDF preview area */
|
| 282 |
#pdf-preview {
|
| 283 |
+
border-radius: 8px; /* Rounded corners */
|
| 284 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* Shadow */
|
| 285 |
}
|
| 286 |
|
| 287 |
.page-btn {
|
| 288 |
+
padding: 12px !important; /* Increase clickable area */
|
| 289 |
+
margin: auto !important; /* Vertical center */
|
| 290 |
}
|
| 291 |
|
| 292 |
+
/* Button hover effect */
|
| 293 |
.page-btn:hover {
|
| 294 |
background: #f0f0f0 !important;
|
| 295 |
transform: scale(1.05);
|
| 296 |
}
|
| 297 |
|
| 298 |
+
/* Layout adjustment */
|
| 299 |
.gr-row {
|
| 300 |
+
gap: 10px !important; /* Element spacing */
|
| 301 |
}
|
| 302 |
|
| 303 |
+
/* Audio player */
|
| 304 |
.audio-panel {
|
| 305 |
margin-top: 15px !important;
|
| 306 |
max-width: 400px;
|
|
|
|
| 310 |
height: 200px !important;
|
| 311 |
}
|
| 312 |
|
| 313 |
+
/* Save functionality area */
|
| 314 |
.save-as-row {
|
| 315 |
margin-top: 15px;
|
| 316 |
padding: 10px;
|
| 317 |
border-top: 1px solid #eee;
|
| 318 |
}
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
/* Download files styling */
|
| 321 |
.download-files {
|
| 322 |
margin-top: 15px;
|
|
|
|
| 329 |
with gr.Blocks(css=css) as demo:
|
| 330 |
gr.Markdown("## NotaGen")
|
| 331 |
|
| 332 |
+
# For storing PDF page count, current page and other information
|
| 333 |
pdf_state = gr.State()
|
| 334 |
|
| 335 |
with gr.Column():
|
| 336 |
with gr.Row():
|
| 337 |
+
# Left sidebar
|
| 338 |
with gr.Column():
|
| 339 |
with gr.Row():
|
| 340 |
period_dd = gr.Dropdown(
|
|
|
|
| 374 |
placeholder="Post-processed ABC scores will be shown here..."
|
| 375 |
)
|
| 376 |
|
| 377 |
+
# Audio playback
|
| 378 |
audio_player = gr.Audio(
|
| 379 |
label="Audio Preview",
|
| 380 |
format="mp3",
|
| 381 |
interactive=False,
|
|
|
|
|
|
|
| 382 |
)
|
| 383 |
|
| 384 |
+
# Right sidebar
|
| 385 |
with gr.Column():
|
| 386 |
+
# Image container
|
| 387 |
pdf_image = gr.Image(
|
| 388 |
label="Sheet Music Preview",
|
| 389 |
show_label=False,
|
|
|
|
| 394 |
show_download_button=False
|
| 395 |
)
|
| 396 |
|
| 397 |
+
# Page navigation buttons
|
| 398 |
with gr.Row():
|
| 399 |
prev_btn = gr.Button(
|
| 400 |
"⬅️ Last Page",
|
|
|
|
| 418 |
type="filepath" # Make sure this is set to filepath
|
| 419 |
)
|
| 420 |
|
| 421 |
+
# Dropdown linking
|
| 422 |
period_dd.change(
|
| 423 |
update_components,
|
| 424 |
inputs=[period_dd, composer_dd],
|
|
|
|
| 430 |
outputs=[composer_dd, instrument_dd]
|
| 431 |
)
|
| 432 |
|
| 433 |
+
# Click generate button, note outputs must match each yield in generate_music
|
| 434 |
generate_btn.click(
|
| 435 |
generate_music,
|
| 436 |
inputs=[period_dd, composer_dd, instrument_dd],
|
| 437 |
outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
|
| 438 |
)
|
| 439 |
|
| 440 |
+
# Page navigation
|
| 441 |
prev_signal = gr.Textbox(value="prev", visible=False)
|
| 442 |
next_signal = gr.Textbox(value="next", visible=False)
|
| 443 |
|
| 444 |
prev_btn.click(
|
| 445 |
update_page,
|
| 446 |
+
inputs=[prev_signal, pdf_state], # ✅ Use component
|
| 447 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
| 448 |
)
|
| 449 |
|
| 450 |
next_btn.click(
|
| 451 |
update_page,
|
| 452 |
+
inputs=[next_signal, pdf_state], # ✅ Use component
|
| 453 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
| 454 |
)
|
| 455 |
|
inference.py
CHANGED
|
@@ -69,30 +69,30 @@ def download_model_weights():
|
|
| 69 |
|
| 70 |
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
1.
|
| 75 |
-
2.
|
| 76 |
-
3.
|
| 77 |
"""
|
| 78 |
-
#
|
| 79 |
model = model.to(dtype=torch.float16)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
for param in model.parameters():
|
| 83 |
if param.dtype == torch.float32:
|
| 84 |
param.requires_grad = False
|
| 85 |
|
| 86 |
-
#
|
| 87 |
if use_gradient_checkpointing:
|
| 88 |
model.gradient_checkpointing_enable()
|
| 89 |
|
| 90 |
return model
|
| 91 |
|
| 92 |
-
|
| 93 |
model = prepare_model_for_kbit_training(
|
| 94 |
model,
|
| 95 |
-
use_gradient_checkpointing=False
|
| 96 |
)
|
| 97 |
|
| 98 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
|
@@ -146,19 +146,19 @@ def complete_brackets(s):
|
|
| 146 |
stack = []
|
| 147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
| 148 |
|
| 149 |
-
#
|
| 150 |
for char in s:
|
| 151 |
if char in bracket_map:
|
| 152 |
stack.append(char)
|
| 153 |
elif char in bracket_map.values():
|
| 154 |
-
#
|
| 155 |
for key, value in bracket_map.items():
|
| 156 |
if value == char:
|
| 157 |
if stack and stack[-1] == key:
|
| 158 |
stack.pop()
|
| 159 |
-
break #
|
| 160 |
|
| 161 |
-
#
|
| 162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
| 163 |
return s + completion
|
| 164 |
|
|
@@ -333,26 +333,24 @@ def inference_patch(period, composer, instrumentation):
|
|
| 333 |
predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
|
| 334 |
input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
|
| 335 |
|
| 336 |
-
if len(byte_list) > 102400:
|
| 337 |
failure_flag = True
|
| 338 |
break
|
| 339 |
-
if time.time() - start_time >
|
| 340 |
failure_flag = True
|
| 341 |
break
|
| 342 |
|
| 343 |
if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
|
| 344 |
-
# Do streaming slicing
|
| 345 |
print('Stream generating...')
|
| 346 |
|
| 347 |
metadata = ''.join(metadata_byte_list)
|
| 348 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
| 349 |
|
| 350 |
if '\n' not in context_tunebody:
|
| 351 |
-
#
|
| 352 |
-
break
|
| 353 |
|
| 354 |
context_tunebody_liness = context_tunebody.split('\n')
|
| 355 |
-
if not context_tunebody.endswith('\n'):
|
| 356 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
|
| 357 |
else:
|
| 358 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
|
|
|
|
| 69 |
|
| 70 |
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
|
| 71 |
"""
|
| 72 |
+
Prepare model for k-bit training.
|
| 73 |
+
Features include:
|
| 74 |
+
1. Convert model to mixed precision (FP16).
|
| 75 |
+
2. Disable unnecessary gradient computations.
|
| 76 |
+
3. Enable gradient checkpointing (optional).
|
| 77 |
"""
|
| 78 |
+
# Convert model to mixed precision
|
| 79 |
model = model.to(dtype=torch.float16)
|
| 80 |
|
| 81 |
+
# Disable gradients for embedding layers
|
| 82 |
for param in model.parameters():
|
| 83 |
if param.dtype == torch.float32:
|
| 84 |
param.requires_grad = False
|
| 85 |
|
| 86 |
+
# Enable gradient checkpointing
|
| 87 |
if use_gradient_checkpointing:
|
| 88 |
model.gradient_checkpointing_enable()
|
| 89 |
|
| 90 |
return model
|
| 91 |
|
| 92 |
+
|
| 93 |
model = prepare_model_for_kbit_training(
|
| 94 |
model,
|
| 95 |
+
use_gradient_checkpointing=False
|
| 96 |
)
|
| 97 |
|
| 98 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
|
|
|
| 146 |
stack = []
|
| 147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
| 148 |
|
| 149 |
+
# Iterate through each character, handle bracket matching
|
| 150 |
for char in s:
|
| 151 |
if char in bracket_map:
|
| 152 |
stack.append(char)
|
| 153 |
elif char in bracket_map.values():
|
| 154 |
+
# Find the corresponding left bracket
|
| 155 |
for key, value in bracket_map.items():
|
| 156 |
if value == char:
|
| 157 |
if stack and stack[-1] == key:
|
| 158 |
stack.pop()
|
| 159 |
+
break # Found matching right bracket, process next character
|
| 160 |
|
| 161 |
+
# Complete missing right brackets (in reverse order of remaining left brackets in stack)
|
| 162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
| 163 |
return s + completion
|
| 164 |
|
|
|
|
| 333 |
predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
|
| 334 |
input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
|
| 335 |
|
| 336 |
+
if len(byte_list) > 102400:
|
| 337 |
failure_flag = True
|
| 338 |
break
|
| 339 |
+
if time.time() - start_time > 10 * 60:
|
| 340 |
failure_flag = True
|
| 341 |
break
|
| 342 |
|
| 343 |
if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
|
|
|
|
| 344 |
print('Stream generating...')
|
| 345 |
|
| 346 |
metadata = ''.join(metadata_byte_list)
|
| 347 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
| 348 |
|
| 349 |
if '\n' not in context_tunebody:
|
| 350 |
+
break # Generated content is all metadata, abandon
|
|
|
|
| 351 |
|
| 352 |
context_tunebody_liness = context_tunebody.split('\n')
|
| 353 |
+
if not context_tunebody.endswith('\n'):
|
| 354 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
|
| 355 |
else:
|
| 356 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
|