feat: treino assíncrono com polling de status

/treinar agora dispara em background thread e retorna imediato.
Novo endpoint GET /treinar/status expõe estado (idle/running/concluido/vazio/erro).
This commit is contained in:
2026-06-10 14:40:42 -03:00
parent 8bdb28cdc1
commit f7ab4219ca
+139 -87
View File
@@ -5,6 +5,7 @@ import os
import io import io
import shutil import shutil
import yaml import yaml
import threading
from PIL import Image from PIL import Image
from PIL import Image as PILImage from PIL import Image as PILImage
from datetime import datetime from datetime import datetime
@@ -15,6 +16,10 @@ BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
MODELOS_CARREGADOS = {} MODELOS_CARREGADOS = {}
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
_status_lock = threading.Lock()
def redimensionar_imagem(caminho): def redimensionar_imagem(caminho):
PILImage.MAX_IMAGE_PIXELS = None PILImage.MAX_IMAGE_PIXELS = None
img = PILImage.open(caminho) img = PILImage.open(caminho)
@@ -40,18 +45,131 @@ def carregar_modelo_do_s3(ambiente: str):
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key) resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
data_s3 = resp_head['LastModified'].timestamp() data_s3 = resp_head['LastModified'].timestamp()
if data_s3 > data_local: if data_s3 > data_local:
log_print(f"🔄 Modelo '{ambiente}' atualizado no S3. Atualizando local...") log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
os.remove(local_path) os.remove(local_path)
if ambiente in MODELOS_CARREGADOS: if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente] del MODELOS_CARREGADOS[ambiente]
if ambiente not in MODELOS_CARREGADOS: if ambiente not in MODELOS_CARREGADOS:
if not os.path.exists(local_path): if not os.path.exists(local_path):
log_print(f"📥 Baixando cérebro '{ambiente}'...") log_print(f"Baixando cerebro '{ambiente}'...")
s3.download_file(BUCKET, s3_key, local_path) s3.download_file(BUCKET, s3_key, local_path)
MODELOS_CARREGADOS[ambiente] = YOLO(local_path) MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
return MODELOS_CARREGADOS[ambiente] return MODELOS_CARREGADOS[ambiente]
def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
"""Execução síncrona do treino — chamada em background thread."""
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
modelo_base = carregar_modelo_do_s3(ambiente)
if pular_triagem:
log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs:
for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except Exception:
continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir = f"{dataset_local}/train/images"
lbl_dir = f"{dataset_local}/train/labels"
if os.path.exists(dataset_local):
shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)
log_print("Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
log_print("TREINAMENTO CONCLUIDO!")
return {"status": "sucesso", "versao": carimbo}
log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
def _treinar_bg(ambiente: str, pular_triagem: bool):
global _treinamento_status
try:
resultado = _executar_treino_sync(ambiente, pular_triagem)
with _status_lock:
if resultado["status"] == "sucesso":
_treinamento_status = {
"status": "concluido",
"ambiente": ambiente,
"versao": resultado.get("versao"),
"detalhe": None,
}
else:
_treinamento_status = {
"status": "vazio",
"ambiente": ambiente,
"versao": None,
"detalhe": "Nenhuma imagem passou na triagem",
}
except Exception as e:
log_print(f"Erro no treino: {str(e)}")
with _status_lock:
_treinamento_status = {
"status": "erro",
"ambiente": ambiente,
"versao": None,
"detalhe": str(e)[:500],
}
@app.post("/detectar") @app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)): async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try: try:
@@ -60,101 +178,35 @@ async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
conteudo = await file.read() conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo)) imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25) results = modelo(imagem, conf=0.25)
deteccoes = [{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)} for r in results for b in r.boxes] deteccoes = [
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
for r in results for b in r.boxes
]
return {"status": "sucesso", "deteccoes": deteccoes} return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e: except Exception as e:
log_print(f"Erro detecção: {str(e)}") log_print(f"Erro deteccao: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar") @app.post("/treinar")
async def treinar(dados: dict): async def treinar(dados: dict):
ambiente = dados.get("ambiente", "gondola") global _treinamento_status
pular_triagem = dados.get("pular_triagem", False) with _status_lock:
try: if _treinamento_status["status"] == "running":
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/" raise HTTPException(status_code=409, detail="Treinamento já em andamento")
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos) ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
modelo_base = carregar_modelo_do_s3(ambiente) threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
return {"status": "iniciado", "ambiente": ambiente}
if pular_triagem:
log_print(f"🚀 TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"🚀 INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs: @app.get("/treinar/status")
for obj in objs['Contents']: async def status_treino():
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')): with _status_lock:
img_key = obj['Key'] return dict(_treinamento_status)
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
# RÉGUA AJUSTADA PARA 0.30 PARA ACEITAR MAIS IMAGENS
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"⚖️ {os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except: continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir, lbl_dir = f"{dataset_local}/train/images", f"{dataset_local}/train/labels"
if os.path.exists(dataset_local): shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True); os.makedirs(lbl_dir, exist_ok=True)
log_print("📥 Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"🏋️ Treinando com {len(os.listdir(img_dir))} fotos. Foco no mAP50...")
# O YOLO imprimirá o mAP automaticamente no log do Docker aqui:
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS: del MODELOS_CARREGADOS[ambiente]
log_print("✅ TREINAMENTO CONCLUÍDO!")
return {"status": "sucesso", "versao": carimbo}
else:
log_print("⚠️ Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
except Exception as e:
log_print(f"💥 Erro: {str(e)}")
return {"status": "erro", "detalhe": str(e)}
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn