Files
ia-gondola-engine/main.py
T

238 lines
9.4 KiB
Python

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from ultralytics import YOLO
import boto3
import os
import io
import shutil
import yaml
import threading
from PIL import Image
from PIL import Image as PILImage
from datetime import datetime
app = FastAPI()
s3 = boto3.client('s3')
BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
MODELOS_CARREGADOS = {}
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
_status_lock = threading.Lock()
_training_logs: list = []
def _on_fit_epoch_end(trainer):
epoch = trainer.epoch + 1
total = trainer.epochs
try:
losses = [float(x) for x in trainer.loss_items] if getattr(trainer, 'loss_items', None) is not None else []
m = getattr(trainer, 'metrics', None) or {}
map50 = float(m.get('metrics/mAP50(B)', 0))
if len(losses) >= 3:
line = f"[{epoch}/{total}] box={losses[0]:.3f} cls={losses[1]:.3f} dfl={losses[2]:.3f} | mAP50={map50:.4f}"
else:
line = f"[{epoch}/{total}] mAP50={map50:.4f}"
except Exception:
line = f"[{epoch}/{total}]"
with _status_lock:
_training_logs.append(line)
if len(_training_logs) > 60:
_training_logs.pop(0)
def redimensionar_imagem(caminho):
PILImage.MAX_IMAGE_PIXELS = None
img = PILImage.open(caminho)
w, h = img.size
max_px = 4096
if w > max_px or h > max_px:
ratio = min(max_px/w, max_px/h)
novo_w = int(w * ratio)
novo_h = int(h * ratio)
img = img.resize((novo_w, novo_h), PILImage.LANCZOS)
img.save(caminho, quality=95)
def log_print(msg):
carimbo = datetime.now().strftime("%H:%M:%S")
print(f"[{carimbo}] {msg}", flush=True)
def carregar_modelo_do_s3(ambiente: str):
s3_key = f"modelos/{ambiente}/atual/cerebro.pt"
local_path = f"/tmp/cerebro_{ambiente}.pt"
if os.path.exists(local_path):
data_local = os.path.getmtime(local_path)
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
data_s3 = resp_head['LastModified'].timestamp()
if data_s3 > data_local:
log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
os.remove(local_path)
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
if ambiente not in MODELOS_CARREGADOS:
if not os.path.exists(local_path):
log_print(f"Baixando cerebro '{ambiente}'...")
s3.download_file(BUCKET, s3_key, local_path)
MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
return MODELOS_CARREGADOS[ambiente]
def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
"""Execução síncrona do treino — chamada em background thread."""
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
modelo_base = carregar_modelo_do_s3(ambiente)
if pular_triagem:
log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs:
for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except Exception:
continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir = f"{dataset_local}/train/images"
lbl_dir = f"{dataset_local}/train/labels"
if os.path.exists(dataset_local):
shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)
log_print("Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
modelo_base.add_callback("on_fit_epoch_end", _on_fit_epoch_end)
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
modelo_base.reset_callbacks()
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
log_print("TREINAMENTO CONCLUIDO!")
return {"status": "sucesso", "versao": carimbo}
log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
def _treinar_bg(ambiente: str, pular_triagem: bool):
global _treinamento_status
with _status_lock:
_training_logs.clear()
try:
resultado = _executar_treino_sync(ambiente, pular_triagem)
with _status_lock:
if resultado["status"] == "sucesso":
_treinamento_status = {
"status": "concluido",
"ambiente": ambiente,
"versao": resultado.get("versao"),
"detalhe": None,
}
else:
_treinamento_status = {
"status": "vazio",
"ambiente": ambiente,
"versao": None,
"detalhe": "Nenhuma imagem passou na triagem",
}
except Exception as e:
log_print(f"Erro no treino: {str(e)}")
with _status_lock:
_treinamento_status = {
"status": "erro",
"ambiente": ambiente,
"versao": None,
"detalhe": str(e)[:500],
}
@app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try:
Image.MAX_IMAGE_PIXELS = None
modelo = carregar_modelo_do_s3(ambiente)
conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25)
deteccoes = [
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
for r in results for b in r.boxes
]
return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e:
log_print(f"Erro deteccao: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar")
async def treinar(dados: dict):
global _treinamento_status
with _status_lock:
if _treinamento_status["status"] == "running":
raise HTTPException(status_code=409, detail="Treinamento já em andamento")
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
return {"status": "iniciado", "ambiente": ambiente}
@app.get("/treinar/status")
async def status_treino():
with _status_lock:
return {**dict(_treinamento_status), "logs": list(_training_logs[-30:])}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)