Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,24 @@ func main() {
Usage: "Do not delete container on exit",
Destination: &args.Docker.KeepContainer,
},
&cli.BoolFlag{
Name: "privileged",
Usage: "Give extended privileges to the container",
Destination: &args.Docker.Privileged,
},
&cli.StringFlag{
Name: "ssh-key",
Usage: "Public SSH key",
Required: true,
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"},
},
&cli.StringFlag{
Name: "pjrt-device",
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
Destination: &args.Docker.PJRTDevice,
EnvVars: []string{"PJRT_DEVICE"},
},
&cli.BoolFlag{
Name: "service",
Usage: "Start as a service",
Expand Down
45 changes: 45 additions & 0 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
Expand Down Expand Up @@ -205,6 +206,15 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
"DSTACK_GPUS_NUM": strconv.Itoa(gpus_num),
}

// Call buildLDLibraryPathEnv and update jobEnvs if no error occurs
newLDPath, err := buildLDLibraryPathEnv()
if err != nil {
log.Info(ctx, "Continuing without updating LD_LIBRARY_PATH")
} else {
jobEnvs["LD_LIBRARY_PATH"] = newLDPath
log.Info(ctx, "New LD_LIBRARY_PATH set", newLDPath)
}

cmd := exec.CommandContext(ctx, ex.jobSpec.Commands[0], ex.jobSpec.Commands[1:]...)
cmd.Env = makeEnv(ex.homeDir, jobEnvs, ex.jobSpec.Env, ex.secrets)
cmd.Cancel = func() error {
Expand Down Expand Up @@ -291,3 +301,38 @@ func isPtyError(err error) bool {
var e *os.PathError
return errors.As(err, &e) && e.Err == syscall.EIO
}

func buildLDLibraryPathEnv() (string, error) {
// Execute shell command to get Python prefix
cmd := exec.Command("bash", "-i", "-c", "python3-config --prefix")
output, err := cmd.Output()

if err != nil {
return "", fmt.Errorf("error executing command: %v", err)
}

// Extract and trim the prefix path
prefixPath := strings.TrimSpace(string(output))

// Check if the prefix path exists
if _, err := os.Stat(prefixPath); os.IsNotExist(err) {
return "", fmt.Errorf("python prefix path does not exist: %s", prefixPath)
}

// Construct the path to Python's shared libraries
sharedLibPath := fmt.Sprintf("%s/lib", prefixPath)

// Get current LD_LIBRARY_PATH
currentLDPath := os.Getenv("LD_LIBRARY_PATH")

// Append Python's shared library path if not already present
if !strings.Contains(currentLDPath, sharedLibPath) {
if currentLDPath == "" {
currentLDPath = sharedLibPath
} else {
currentLDPath = fmt.Sprintf("%s:%s", currentLDPath, sharedLibPath)
}
}

return currentLDPath, nil
}
16 changes: 16 additions & 0 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,21 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str
return "", tracerr.Wrap(err)
}

//Set the environment variables
envVars := []string{}
if dockerParams.DockerPJRTDevice() != "" {
envVars = append(envVars, fmt.Sprintf("PJRT_DEVICE=%s", dockerParams.DockerPJRTDevice()))
}

containerConfig := &container.Config{
Image: taskConfig.ImageName,
Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")},
Entrypoint: []string{"/bin/sh", "-c"},
ExposedPorts: exposePorts(dockerParams.DockerPorts()...),
Env: envVars,
}
hostConfig := &container.HostConfig{
Privileged: dockerParams.DockerPrivileged(),
NetworkMode: getNetworkMode(),
PortBindings: bindPorts(dockerParams.DockerPorts()...),
PublishAllPorts: true,
Expand Down Expand Up @@ -426,6 +434,14 @@ func (c CLIArgs) DockerKeepContainer() bool {
return c.Docker.KeepContainer
}

func (c CLIArgs) DockerPrivileged() bool {
return c.Docker.Privileged
}

func (c CLIArgs) DockerPJRTDevice() string {
return c.Docker.PJRTDevice
}

func (c CLIArgs) DockerShellCommands(publicKeys []string) []string {
concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys
if len(publicKeys) > 0 {
Expand Down
8 changes: 8 additions & 0 deletions runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func (c *dockerParametersMock) DockerKeepContainer() bool {
return false
}

func (c *dockerParametersMock) DockerPrivileged() bool {
return false
}

func (c *dockerParametersMock) DockerPJRTDevice() string {
return ""
}

func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string {
userPublicKey := c.publicSSHKey
if len(publicKeys) > 0 {
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
)

type DockerParameters interface {
DockerPrivileged() bool
DockerKeepContainer() bool
DockerShellCommands([]string) []string
DockerMounts(string) ([]mount.Mount, error)
DockerPorts() []int
MakeRunnerDir() (string, error)
DockerPJRTDevice() string
}

type CLIArgs struct {
Expand All @@ -38,6 +40,8 @@ type CLIArgs struct {
SSHPort int
KeepContainer bool
ConcatinatedPublicSSHKeys string
Privileged bool
PJRTDevice string
}
}

Expand Down
14 changes: 9 additions & 5 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

logger = get_logger(__name__)


DSTACK_WORKING_DIR = "/root/.dstack"


Expand Down Expand Up @@ -121,14 +120,16 @@ def get_shim_env(build: str, authorized_keys: List[str]) -> Dict[str, str]:
return envs


def get_shim_commands(authorized_keys: List[str]) -> List[str]:
def get_shim_commands(
authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None
) -> List[str]:
build = get_dstack_runner_version()
commands = get_shim_pre_start_commands(
build,
)
for k, v in get_shim_env(build, authorized_keys).items():
commands += [f'export "{k}={v}"']
commands += get_run_shim_script()
commands += get_run_shim_script(is_privileged, pjrt_device)
return commands


Expand Down Expand Up @@ -161,10 +162,13 @@ def get_shim_pre_start_commands(build: str) -> List[str]:
]


def get_run_shim_script() -> List[str]:
def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]:
dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev"
privileged_flag = "--privileged" if is_privileged else ""
pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""

return [
f"nohup dstack-shim {dev_flag} docker --keep-container >{DSTACK_WORKING_DIR}/shim.log 2>&1 &",
f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} {pjrt_device_env} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &",
]


Expand Down
8 changes: 5 additions & 3 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]:


def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
commands = get_shim_commands(authorized_keys=authorized_keys)
commands = get_shim_commands(
authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU"
)
startup_script = " ".join([" && ".join(commands)])
startup_script = "#! /bin/bash\n" + startup_script
return startup_script
Expand All @@ -469,9 +471,9 @@ def _is_pod(instance_name: str) -> bool:
tensor_cores = int(tensor_cores)
except ValueError:
raise ValueError(f"Invalid number in tpu tensor cores: {tensor_cores}")
if version in ["v2", "v3"]:
if version in ["v2", "v3", "v5p", "v5litepod"]:
return tensor_cores > 8
elif version in ["v4", "v5p", "v5litepod"]:
elif version == "v4":
return True
else:
raise ValueError(f"Unknown TPU version: {version}")