# frozen_string_literal: true

require "excon"
require "forwardable"

require_relative "../../lib/util"

class Prog::Ai::InferenceEndpointReplicaNexus < Prog::Base
  subject_is :inference_endpoint_replica

  extend Forwardable
  def_delegators :inference_endpoint_replica, :vm, :inference_endpoint, :load_balancer_vm_port
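
  # Creates a replica for the given inference endpoint: provisions a VM in the
  # shared inference service project, attaches it to the endpoint's load
  # balancer, and creates the replica record together with its strand.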
  def self.assemble(inference_endpoint_id)
    DB.transaction do
      ubid = InferenceEndpointReplica.generate_ubid

      inference_endpoint = InferenceEndpoint[inference_endpoint_id]
      vm_st = Prog::Vm::Nexus.assemble_with_sshable(
        Config.inference_endpoint_service_project_id,
        sshable_unix_user: "ubi",
        location_id: inference_endpoint.location_id,
        name: ubid.to_s,
        size: inference_endpoint.vm_size,
        storage_volumes: inference_endpoint.storage_volumes.map { it.transform_keys(&:to_sym) },
        boot_image: inference_endpoint.boot_image,
        private_subnet_id: inference_endpoint.load_balancer.private_subnet.id,
        enable_ip4: true,
        gpu_count: inference_endpoint.gpu_count
      )

      inference_endpoint.load_balancer.add_vm(vm_st.subject)

      replica = InferenceEndpointReplica.create(
        inference_endpoint_id: inference_endpoint_id,
        vm_id: vm_st.id
      ) { it.id = ubid.to_uuid }

      Strand.create_with_id(replica.id, prog: "Ai::InferenceEndpointReplicaNexus", label: "start")
    end
  end
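
  # Runs before every label: once the destroy semaphore is set, cancel any
  # budded operation and hop to destroy.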
  def before_run
    when_destroy_set? do
      if strand.label != "destroy"
        hop_destroy
      elsif strand.stack.count > 1
        pop "operation is cancelled due to the destruction of the inference endpoint replica"
      end
    end
  end
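
  # Waits until the underlying VM has been provisioned.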
  label def start
    nap 5 unless vm.strand.label == "wait"

    hop_bootstrap_rhizome
  end
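
  # Installs the rhizome control scripts on the VM.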
  label def bootstrap_rhizome
    register_deadline("wait", 15 * 60)

    bud Prog::BootstrapRhizome, {"target_folder" => "inference_endpoint", "subject_id" => vm.id, "user" => "ubi"}
    hop_wait_bootstrap_rhizome
  end
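
  # Waits for the budded BootstrapRhizome prog to finish.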
  label def wait_bootstrap_rhizome
    reap(:download_lb_cert)
  end
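
  # Fetches the load balancer's TLS certificate onto the VM.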
  label def download_lb_cert
    vm.sshable.cmd("sudo inference_endpoint/bin/download-lb-cert")
    hop_setup_external
  end
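
  # For externally hosted engines (currently RunPod), creates the external pod
  # and polls until it has a public IP and port; other engines hop straight to
  # setup.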
  label def setup_external
    case inference_endpoint.engine
    when "runpod"
      if inference_endpoint_replica.external_state["pod_id"]
        if (pod = get_runpod_pod) && pod[:ip] && pod[:port]
          inference_endpoint_replica.update(external_state: pod)
          hop_setup
        end
      else
        inference_endpoint_replica.update(external_state: {"pod_id" => create_runpod_pod})
      end
    else
      hop_setup
    end

    nap 10
  end
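
  # Runs setup-replica on the VM under the daemonizer, passing the engine
  # start command and gateway configuration as JSON on stdin.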
  label def setup
    case vm.sshable.cmd("common/bin/daemonizer --check setup")
    when "Succeeded"
      hop_wait_endpoint_up
    when "Failed", "NotStarted"
      params = {
        engine_start_cmd:,
        replica_ubid: inference_endpoint_replica.ubid,
        ssl_crt_path: "/ie/workdir/ssl/ubi_cert.pem",
        ssl_key_path: "/ie/workdir/ssl/ubi_key.pem",
        gateway_port: inference_endpoint.load_balancer.ports.first.dst_port,
        max_requests: inference_endpoint.max_requests
      }
      params_json = JSON.generate(params)
      vm.sshable.cmd("common/bin/daemonizer 'sudo inference_endpoint/bin/setup-replica' setup", stdin: params_json)
    end

    nap 5
  end
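
  # Waits for the load balancer to report the replica's port as up.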
  label def wait_endpoint_up
    hop_wait if available?

    nap 5
  end
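
  # Steady state: verifies availability and pings the gateway every two minutes.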
  label def wait
    hop_unavailable unless available?
    ping_gateway

    nap 120
  end
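
  # Tears down the replica: external pod, child strands, load balancer
  # membership, and the VM itself.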
  label def destroy
    decr_destroy

    resolve_page
    delete_runpod_pod
    strand.children.each { it.destroy }
    inference_endpoint.load_balancer.evacuate_vm(vm)
    inference_endpoint.load_balancer.remove_vm(vm)
    vm.incr_destroy
    inference_endpoint_replica.destroy

    pop "inference endpoint replica is deleted"
  end
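
  # Pages until the replica becomes available again, unless the endpoint is in
  # maintenance mode.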
  label def unavailable
    if available?
      resolve_page
      hop_wait
    end

    create_page unless inference_endpoint.maintenance_set?
    nap 30
  end
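
  # A replica is available when its load balancer port reports "up".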
  def available?
    load_balancer_vm_port.reload.state == "up"
  end
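
  # Opens a warning page with enough context to debug the unavailable replica.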
  def create_page
    extra_data = {
      inference_endpoint_ubid: inference_endpoint.ubid,
      inference_endpoint_is_public: inference_endpoint.is_public,
      inference_endpoint_location: inference_endpoint.location.name,
      inference_endpoint_name: inference_endpoint.name,
      inference_endpoint_model_name: inference_endpoint.model_name,
      inference_endpoint_replica_count: inference_endpoint.replica_count,
      load_balancer_ubid: inference_endpoint.load_balancer.ubid,
      private_subnet_ubid: inference_endpoint.load_balancer.private_subnet.ubid,
      replica_ubid: inference_endpoint_replica.ubid,
      vm_ubid: vm.ubid,
      vm_ip: vm.sshable.host,
      vm_host_ubid: vm.vm_host.ubid,
      vm_host_ip: vm.vm_host.sshable.host
    }
    Prog::PageNexus.assemble("Replica #{inference_endpoint_replica.ubid.to_s[0..7]} of inference endpoint #{inference_endpoint.name} is unavailable",
      ["InferenceEndpointReplicaUnavailable", inference_endpoint_replica.ubid],
      inference_endpoint_replica.ubid, severity: "warning", extra_data:)
  end
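
  # Resolves the unavailability page, if one is open.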
  def resolve_page
    Page.from_tag_parts("InferenceEndpointReplicaUnavailable", inference_endpoint_replica.ubid)&.incr_resolve
  end

  # pushes latest config to inference gateway and collects billing information
  def ping_gateway
    api_key_ds = DB[:api_key]
      .where(owner_table: "project")
      .where(used_for: "inference_endpoint")
      .where(is_valid: true)
      .where(owner_id: Sequel[:project][:id])
      .exists

    eligible_projects_ds = Project.where(api_key_ds)
    free_quota_exhausted_projects_ds = FreeQuota.get_exhausted_projects("inference-tokens")
    eligible_projects_ds = eligible_projects_ds.where(id: inference_endpoint.project.id) unless inference_endpoint.is_public
    valid_payment_method_ds = DB[:payment_method]
      .where(fraud: false)
      .select_group(:billing_info_id)
      .select_append { Sequel.as(Sequel.lit("1"), :valid_payment_method) }
    eligible_projects_ds = eligible_projects_ds
      .left_outer_join(valid_payment_method_ds, [:billing_info_id])
      .exclude(valid_payment_method: nil, credit: 0.0, id: free_quota_exhausted_projects_ds)

    eligible_projects = eligible_projects_ds.all
      .select(&:active?)
      .map do
        {
          ubid: it.ubid,
          api_keys: it.api_keys.select { |k| k.used_for == "inference_endpoint" && k.is_valid }.map { |k| Digest::SHA2.hexdigest(k.key) },
          quota_rps: inference_endpoint.max_project_rps,
          quota_tps: inference_endpoint.max_project_tps
        }
      end

    body = {
      replica_ubid: inference_endpoint_replica.ubid,
      public_endpoint: inference_endpoint.is_public,
      projects: eligible_projects
    }

    resp = vm.sshable.cmd("sudo curl -m 10 --no-progress-meter -H \"Content-Type: application/json\" -X POST --data-binary @- --unix-socket /ie/workdir/inference-gateway.clover.sock http://localhost/control", stdin: body.to_json)
    project_usage = JSON.parse(resp)["projects"]
    Clog.emit("Successfully pinged inference gateway.") { {inference_endpoint: inference_endpoint.ubid, replica: inference_endpoint_replica.ubid, project_usage: project_usage} }
    update_billing_records(project_usage, "input", "prompt_token_count")
    update_billing_records(project_usage, "output", "completion_token_count")
  end
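
  # Aggregates token usage into daily billing records, one per project, model,
  # and token type ("input" or "output").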
  def update_billing_records(project_usage, token_type, usage_key)
    resource_family = "#{inference_endpoint.model_name}-#{token_type}"
    rate = BillingRate.from_resource_properties("InferenceTokens", resource_family, "global")
    return if rate["unit_price"].zero?
    rate_id = rate["id"]
    begin_time = Time.now.to_date.to_time
    end_time = begin_time + 24 * 60 * 60

    project_usage.each do |usage|
      tokens = usage[usage_key]
      next if tokens.zero?
      project = Project[id: UBID.to_uuid(usage["ubid"])]

      begin
        today_record = BillingRecord
          .where(project_id: project.id, resource_id: inference_endpoint.id, billing_rate_id: rate_id)
          .where { Sequel.pg_range(it.span).overlaps(Sequel.pg_range(begin_time...end_time)) }
          .first

        if today_record
          today_record.amount = Sequel[:amount] + tokens
          today_record.save_changes(validate: false)
        else
          BillingRecord.create(
            project_id: project.id,
            resource_id: inference_endpoint.id,
            resource_name: "#{resource_family} #{begin_time.strftime("%Y-%m-%d")}",
            billing_rate_id: rate_id,
            span: Sequel.pg_range(begin_time...end_time),
            amount: tokens
          )
        end
      rescue Sequel::Error => ex
        Clog.emit("Failed to update billing record") { {billing_record_update_error: {project_ubid: project.ubid, model_name: inference_endpoint.model_name, replica_ubid: inference_endpoint_replica.ubid, tokens: tokens, exception: Util.exception_to_hash(ex)}} }
      end
    end
  end
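
  # Command that starts the inference engine on the replica; for RunPod this is
  # an SSH tunnel to the external pod rather than a local vllm process.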
  def engine_start_cmd
    case inference_endpoint.engine
    when "vllm"
      env = (inference_endpoint.gpu_count == 0) ? "vllm-cpu" : "vllm"
      "/opt/miniconda/envs/#{env}/bin/vllm serve /ie/models/model --served-model-name #{inference_endpoint.model_name} --disable-log-requests --host 127.0.0.1 #{inference_endpoint.engine_params}"
    when "runpod"
      "ssh -N -L 8000:localhost:8000 root@#{inference_endpoint_replica.external_state["ip"]} -p #{inference_endpoint_replica.external_state["port"]} -i /ie/workdir/.ssh/runpod -o UserKnownHostsFile=/ie/workdir/.ssh/known_hosts -o StrictHostKeyChecking=accept-new"
    else
      fail "BUG: unsupported inference engine"
    end
  end
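
  # Creates a RunPod pod via their GraphQL API and returns its id. Idempotent:
  # if a pod named after this replica already exists, returns that pod's id.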
  def create_runpod_pod
    response = Excon.post("https://api.runpod.io/graphql",
      headers: {"content-type" => "application/json", "authorization" => "Bearer #{Config.runpod_api_key}"},
      body: {"query" => "query Pods { myself { pods { id name runtime { ports { ip isIpPublic privatePort publicPort type } } } } }"}.to_json,
      expects: 200)

    pods = JSON.parse(response.body)["data"]["myself"]["pods"]
    pod = pods.find { |pod| pod["name"] == inference_endpoint_replica.ubid }

    return pod["id"] if pod

    ssh_keys = vm.sshable.cmd(<<-CMD) + Config.operator_ssh_public_keys
      if ! sudo test -f /ie/workdir/.ssh/runpod; then
        sudo -u ie mkdir -p /ie/workdir/.ssh
        sudo -u ie ssh-keygen -t ed25519 -C #{inference_endpoint_replica.ubid}@ubicloud.com -f /ie/workdir/.ssh/runpod -N '' -q
      fi
      sudo cat /ie/workdir/.ssh/runpod.pub
    CMD

    vllm_params = "--served-model-name #{inference_endpoint.model_name} --disable-log-requests --host 127.0.0.1 #{inference_endpoint.engine_params}"

    config = inference_endpoint.external_config
    graphql_query = <<~GRAPHQL
      mutation {
        podFindAndDeployOnDemand(
          input: {
            cloudType: ALL
            dataCenterId: "#{config["data_center"]}"
            gpuCount: #{config["gpu_count"]}
            gpuTypeId: "#{config["gpu_type"]}"
            containerDiskInGb: #{config["disk_gib"]}
            minVcpuCount: #{config["min_vcpu_count"]}
            minMemoryInGb: #{config["min_memory_gib"]}
            imageName: "#{config["image_name"]}"
            name: "#{inference_endpoint_replica.ubid}"
            volumeInGb: 0
            ports: "22/tcp"
            env: [
              { key: "HF_TOKEN", value: "#{Config.huggingface_token}" },
              { key: "HF_HUB_ENABLE_HF_TRANSFER", value: "1" },
              { key: "MODEL_PATH", value: "/model" },
              { key: "MODEL_NAME_HF", value: "#{config["model_name_hf"]}" },
              { key: "VLLM_PARAMS", value: "#{vllm_params}" },
              { key: "SSH_KEYS", value: "#{ssh_keys.gsub("\n", "\\n")}" }
            ]
          }
        ) {
          id
          imageName
          env
          machineId
          machine {
            podHostId
          }
        }
      }
    GRAPHQL

    response = Excon.post("https://api.runpod.io/graphql",
      headers: {"content-type" => "application/json", "authorization" => "Bearer #{Config.runpod_api_key}"},
      body: {"query" => graphql_query}.to_json,
      expects: 200)

    JSON.parse(response.body)["data"]["podFindAndDeployOnDemand"]["id"]
  end
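
  # Looks up the pod and extracts its public SSH endpoint, if already assigned.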
  def get_runpod_pod
    pod_id = inference_endpoint_replica.external_state.fetch("pod_id")
    response = Excon.post("https://api.runpod.io/graphql",
      headers: {"content-type" => "application/json", "authorization" => "Bearer #{Config.runpod_api_key}"},
      body: {"query" => "query Pod { pod(input: {podId: \"#{pod_id}\"}) { id name runtime { ports { ip isIpPublic privatePort publicPort type } } } }"}.to_json,
      expects: 200)

    pod = JSON.parse(response.body)["data"]["pod"]
    fail "BUG: pod not found" unless pod
    fail "BUG: unexpected pod id" unless pod_id == pod["id"]

    port = pod["runtime"]["ports"].find { |port| port["type"] == "tcp" && port["isIpPublic"] }

    {
      pod_id: pod["id"],
      ip: port&.fetch("ip"),
      port: port&.fetch("publicPort")
    }
  end
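
  # Terminates the RunPod pod, if any, and clears the stored external state.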
  def delete_runpod_pod
    return unless (pod_id = inference_endpoint_replica.external_state["pod_id"])
    Excon.post("https://api.runpod.io/graphql",
      headers: {"content-type" => "application/json", "authorization" => "Bearer #{Config.runpod_api_key}"},
      body: {"query" => "mutation { podTerminate(input: {podId: \"#{pod_id}\"}) }"}.to_json,
      expects: 200)
    inference_endpoint_replica.update(external_state: "{}")
  end
end